golib/auth/jwt/tests/round-trip-go-jwt/round_trip_test.go
AJ ONeal eb65999a6c
refactor: rename PublicKey.Key to Pub, PrivateKey.privKey to Priv
- PublicKey.Key → PublicKey.Pub (CryptoPublicKey)
- PrivateKey.privKey → PrivateKey.Priv (crypto.Signer, now public)
- Update all internal usages in jwk.go, sign.go, verify.go
- Update all test usages in jwt_test.go, coverage_test.go, edge_test.go
- Update interop tests (round-trip-go-jose, round-trip-go-jwt, nuance)
- Update SKILL.md documentation

Breaking change: PublicKey.Key renamed to PublicKey.Pub
2026-03-18 18:58:49 -06:00

776 lines
22 KiB
Go

// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
// Package roundtrip_test verifies interoperability between this library and
// github.com/golang-jwt/jwt/v5. It lives in a separate module (tests/go.mod)
// so that the golang-jwt dependency does not leak into the main module graph.
//
// Tests cover:
// - Our sign + their verify (Ed25519, EC P-256, P-384, P-521, RSA)
// - Their sign + our verify (Ed25519, EC P-256, P-384, P-521, RSA)
// - Known/fixed keys: deterministic key material for reproducible tests
// - Stress tests: 1,000 keys per algorithm (RSA: 10 by default, 100 with -long)
// to catch ASN.1/padding edge cases
// - JWK key round-trip: marshal/unmarshal private and public keys, then
// confirm the recovered keys interoperate correctly.
package roundtrip_test
import (
	"crypto"
	"crypto/ecdh"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"testing"
	"time"

	gjwt "github.com/golang-jwt/jwt/v5"
	"github.com/therootcompany/golib/auth/jwt"
)
var (
	// longTests is set by the -long flag. It raises the RSA stress-test
	// iteration count from 10 to 100; it defaults to false.
	longTests = flag.Bool("long", false, "run extended stress tests (100 RSA iterations instead of 10)")
)
// --- helpers ---
// testClaims returns a fresh set of claims for a test iteration.
func testClaims(sub string) *jwt.TokenClaims {
now := time.Now()
return &jwt.TokenClaims{
Iss: "https://example.com",
Sub: sub,
Exp: now.Add(time.Hour).Unix(),
IAt: now.Unix(),
}
}
// hashReader is an io.Reader yielding an endless, reproducible byte stream:
// SHA-256(seed), then SHA-256 of that digest, then SHA-256 of the next, and
// so on. It is NOT a secure random source and exists only so test key
// material can be derived from a fixed seed.
type hashReader struct {
	state [32]byte // current digest in the chain
	pos   int      // how many bytes of state have already been handed out
}

// deterministicRand returns a reader whose output depends only on seed.
func deterministicRand(seed string) io.Reader {
	return &hashReader{state: sha256.Sum256([]byte(seed))}
}

// Read fills all of p from the hash chain, moving on to the next digest
// whenever the current one is used up. It always returns len(p), nil.
func (r *hashReader) Read(p []byte) (int, error) {
	total := len(p)
	for rest := p; len(rest) > 0; {
		if r.pos >= len(r.state) {
			r.state = sha256.Sum256(r.state[:])
			r.pos = 0
		}
		k := copy(rest, r.state[r.pos:])
		r.pos += k
		rest = rest[k:]
	}
	return total, nil
}
// mustPK converts signer into a *jwt.PrivateKey tagged with kid via
// jwt.FromPrivateKey, stopping the calling test at once if that fails.
func mustPK(t testing.TB, signer crypto.Signer, kid string) *jwt.PrivateKey {
	t.Helper()
	pk, err := jwt.FromPrivateKey(signer, kid)
	if err == nil {
		return pk
	}
	t.Fatal(err)
	return nil // unreachable: t.Fatal stops the goroutine
}
// assertOurSignTheirVerify signs fresh claims for sub with this library, using
// pk as the only signing key. It then checks three things:
//   - golang-jwt verifies the token with gjwtPub;
//   - the token header alg equals gjwtMethod's;
//   - the sub and iss claims come back unchanged.
func assertOurSignTheirVerify(t *testing.T, pk *jwt.PrivateKey, gjwtMethod gjwt.SigningMethod, gjwtPub any, sub string) {
	t.Helper()

	ourSigner, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
	if err != nil {
		t.Fatalf("NewSigner: %v", err)
	}
	sent := testClaims(sub)
	compact, err := ourSigner.SignToString(sent)
	if err != nil {
		t.Fatalf("SignToString: %v", err)
	}

	wantAlg := gjwtMethod.Alg()
	keyFunc := func(tok *gjwt.Token) (any, error) {
		// Only hand out the key for the alg under test.
		if gotAlg := tok.Method.Alg(); gotAlg != wantAlg {
			return nil, fmt.Errorf("unexpected alg: got %q, want %q", gotAlg, wantAlg)
		}
		return gjwtPub, nil
	}
	parsed, err := gjwt.ParseWithClaims(compact, &gjwt.RegisteredClaims{}, keyFunc)
	if err != nil {
		t.Fatalf("golang-jwt verify failed: %v", err)
	}

	received, isRegistered := parsed.Claims.(*gjwt.RegisteredClaims)
	if !isRegistered || !parsed.Valid {
		t.Fatal("token invalid or claims unreadable")
	}
	if received.Subject != sub {
		t.Errorf("sub: got %q, want %q", received.Subject, sub)
	}
	if received.Issuer != sent.Iss {
		t.Errorf("iss: got %q, want %q", received.Issuer, sent.Iss)
	}
}
// assertTheirSignOurVerify signs registered claims for sub with golang-jwt,
// using gjwtMethod, gjwtPriv, and a "kid" header set to kid. It checks that
// this library's Verifier, configured with only ourPub, accepts the token
// and decodes the sub and iss claims unchanged.
func assertTheirSignOurVerify(t *testing.T, gjwtMethod gjwt.SigningMethod, gjwtPriv any, kid string, ourPub jwt.PublicKey, sub string) {
	t.Helper()
	now := time.Now()
	gClaims := gjwt.RegisteredClaims{
		Issuer:    "https://example.com",
		Subject:   sub,
		ExpiresAt: gjwt.NewNumericDate(now.Add(time.Hour)),
		IssuedAt:  gjwt.NewNumericDate(now),
	}
	tok := gjwt.NewWithClaims(gjwtMethod, gClaims)
	// Tag the token with kid so it identifies ourPub (presumably used by the
	// Verifier for key selection — confirm against verify.go).
	tok.Header["kid"] = kid
	tokenStr, err := tok.SignedString(gjwtPriv)
	if err != nil {
		t.Fatalf("golang-jwt sign: %v", err)
	}

	// Previously this error was discarded; surface it so a bad key fails
	// here with a clear message instead of later in VerifyJWT.
	verifier, err := jwt.NewVerifier([]jwt.PublicKey{ourPub})
	if err != nil {
		t.Fatalf("NewVerifier: %v", err)
	}
	jws, err := verifier.VerifyJWT(tokenStr)
	if err != nil {
		t.Fatalf("our verify failed: %v", err)
	}
	var decoded jwt.TokenClaims
	if err := jws.UnmarshalClaims(&decoded); err != nil {
		t.Fatalf("UnmarshalClaims: %v", err)
	}
	if decoded.Sub != sub {
		t.Errorf("sub: got %q, want %q", decoded.Sub, sub)
	}
	if decoded.Iss != gClaims.Issuer {
		t.Errorf("iss: got %q, want %q", decoded.Iss, gClaims.Issuer)
	}
}
// stressIteration exercises one key pair in both directions, and failures
// report the iteration index i:
//   - a token signed by this library (pk) must verify under golang-jwt with
//     gjwtPub, accepting only gjwtMethod's alg;
//   - a token signed by golang-jwt (gjwtPriv, kid = pk.KID) must verify under
//     this library with pub.
func stressIteration(t *testing.T, i int, pk *jwt.PrivateKey, pub jwt.PublicKey, gjwtMethod gjwt.SigningMethod, gjwtPriv any, gjwtPub any) {
	t.Helper()
	sub := fmt.Sprintf("stress-%d", i)

	// Our sign, their verify.
	signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
	if err != nil {
		t.Fatalf("iter %d: NewSigner: %v", i, err)
	}
	tokenStr, err := signer.SignToString(testClaims(sub))
	if err != nil {
		t.Fatalf("iter %d: SignToString: %v", i, err)
	}
	_, err = gjwt.ParseWithClaims(tokenStr, &gjwt.RegisteredClaims{},
		func(*gjwt.Token) (any, error) { return gjwtPub, nil },
		// Pin the alg, matching the check in assertOurSignTheirVerify, so a
		// signer emitting the wrong alg header is caught here too.
		gjwt.WithValidMethods([]string{gjwtMethod.Alg()}),
	)
	if err != nil {
		t.Fatalf("iter %d: golang-jwt verify: %v", i, err)
	}

	// Their sign, our verify.
	gClaims := gjwt.RegisteredClaims{
		Subject:   sub,
		ExpiresAt: gjwt.NewNumericDate(time.Now().Add(time.Hour)),
	}
	tok := gjwt.NewWithClaims(gjwtMethod, gClaims)
	tok.Header["kid"] = pk.KID
	tokenStr, err = tok.SignedString(gjwtPriv)
	if err != nil {
		t.Fatalf("iter %d: golang-jwt sign: %v", i, err)
	}
	verifier, err := jwt.NewVerifier([]jwt.PublicKey{pub})
	if err != nil {
		t.Fatalf("iter %d: NewVerifier: %v", i, err)
	}
	if _, err := verifier.VerifyJWT(tokenStr); err != nil {
		t.Fatalf("iter %d: our verify: %v", i, err)
	}
}
// --- Our sign, their verify (all algorithms) ---
func TestOurSignTheirVerify_EdDSA(t *testing.T) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
pub := priv.Public().(ed25519.PublicKey)
assertOurSignTheirVerify(t,
mustPK(t, priv, "k1"),
gjwt.SigningMethodEdDSA, pub, "user-eddsa")
}
func TestOurSignTheirVerify_ES256(t *testing.T) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
assertOurSignTheirVerify(t,
mustPK(t, priv, "k1"),
gjwt.SigningMethodES256, &priv.PublicKey, "user-es256")
}
func TestOurSignTheirVerify_ES384(t *testing.T) {
priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
t.Fatal(err)
}
assertOurSignTheirVerify(t,
mustPK(t, priv, "k1"),
gjwt.SigningMethodES384, &priv.PublicKey, "user-es384")
}
func TestOurSignTheirVerify_ES512(t *testing.T) {
priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
t.Fatal(err)
}
assertOurSignTheirVerify(t,
mustPK(t, priv, "k1"),
gjwt.SigningMethodES512, &priv.PublicKey, "user-es512")
}
func TestOurSignTheirVerify_RS256(t *testing.T) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
assertOurSignTheirVerify(t,
mustPK(t, priv, "k1"),
gjwt.SigningMethodRS256, &priv.PublicKey, "user-rs256")
}
// --- Their sign, our verify (all algorithms) ---
func TestTheirSignOurVerify_EdDSA(t *testing.T) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
pub := priv.Public().(ed25519.PublicKey)
assertTheirSignOurVerify(t,
gjwt.SigningMethodEdDSA, priv, "k1",
jwt.PublicKey{Pub: pub, KID: "k1"}, "user-eddsa")
}
func TestTheirSignOurVerify_ES256(t *testing.T) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
assertTheirSignOurVerify(t,
gjwt.SigningMethodES256, priv, "k1",
jwt.PublicKey{Pub: &priv.PublicKey, KID: "k1"}, "user-es256")
}
func TestTheirSignOurVerify_ES384(t *testing.T) {
priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
t.Fatal(err)
}
assertTheirSignOurVerify(t,
gjwt.SigningMethodES384, priv, "k1",
jwt.PublicKey{Pub: &priv.PublicKey, KID: "k1"}, "user-es384")
}
func TestTheirSignOurVerify_ES512(t *testing.T) {
priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
t.Fatal(err)
}
assertTheirSignOurVerify(t,
gjwt.SigningMethodES512, priv, "k1",
jwt.PublicKey{Pub: &priv.PublicKey, KID: "k1"}, "user-es512")
}
func TestTheirSignOurVerify_RS256(t *testing.T) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
assertTheirSignOurVerify(t,
gjwt.SigningMethodRS256, priv, "k1",
jwt.PublicKey{Pub: &priv.PublicKey, KID: "k1"}, "user-rs256")
}
// --- Known key tests ---
//
// Ed25519 and EC use deterministic key material so failures are reproducible
// across runs:
//   - Ed25519 uses ed25519.NewKeyFromSeed.
//   - EC uses knownECDSAKey, which reads the private scalar straight from a
//     SHA-256 hash chain seeded with a fixed string.
//
// ecdsa.GenerateKey and rsa.GenerateKey are NOT fed the seeded reader. They do
// not guarantee the same key for the same reader: they may consume an extra
// random byte, and newer Go releases may ignore the supplied reader entirely.
// The standard library has no deterministic RSA key generator, so the RS256
// case uses a fresh random key on every run.
func TestKnownKeys(t *testing.T) {
	t.Run("EdDSA", func(t *testing.T) {
		seed := make([]byte, ed25519.SeedSize)
		for i := range seed {
			seed[i] = byte(i)
		}
		priv := ed25519.NewKeyFromSeed(seed)
		pub := priv.Public().(ed25519.PublicKey)
		kid := "known-ed"
		pk := mustPK(t, priv, kid)
		pubKey := jwt.PublicKey{Pub: pub, KID: kid}
		assertOurSignTheirVerify(t, pk, gjwt.SigningMethodEdDSA, pub, "known-ed-ours")
		assertTheirSignOurVerify(t, gjwt.SigningMethodEdDSA, priv, kid, pubKey, "known-ed-theirs")
	})

	// The kid doubles as the hash-chain seed, as in the original tests.
	ecCases := []struct {
		name    string
		kid     string
		curve   ecdh.Curve
		bitSize int
		method  gjwt.SigningMethod
	}{
		{"ES256", "known-es256", ecdh.P256(), 256, gjwt.SigningMethodES256},
		{"ES384", "known-es384", ecdh.P384(), 384, gjwt.SigningMethodES384},
		{"ES512", "known-es512", ecdh.P521(), 521, gjwt.SigningMethodES512},
	}
	for _, tc := range ecCases {
		t.Run(tc.name, func(t *testing.T) {
			priv := knownECDSAKey(t, tc.curve, tc.bitSize, tc.kid)
			if again := knownECDSAKey(t, tc.curve, tc.bitSize, tc.kid); !again.Equal(priv) {
				t.Fatalf("%s: key derivation is not reproducible", tc.kid)
			}
			pk := mustPK(t, priv, tc.kid)
			pubKey := jwt.PublicKey{Pub: &priv.PublicKey, KID: tc.kid}
			assertOurSignTheirVerify(t, pk, tc.method, &priv.PublicKey, tc.kid+"-ours")
			assertTheirSignOurVerify(t, tc.method, priv, tc.kid, pubKey, tc.kid+"-theirs")
		})
	}

	t.Run("RS256", func(t *testing.T) {
		// Not reproducible: see the note above TestKnownKeys.
		priv, err := rsa.GenerateKey(rand.Reader, 2048)
		if err != nil {
			t.Fatal(err)
		}
		kid := "known-rs256"
		pk := mustPK(t, priv, kid)
		pubKey := jwt.PublicKey{Pub: &priv.PublicKey, KID: kid}
		assertOurSignTheirVerify(t, pk, gjwt.SigningMethodRS256, &priv.PublicKey, "known-rs256-ours")
		assertTheirSignOurVerify(t, gjwt.SigningMethodRS256, priv, kid, pubKey, "known-rs256-theirs")
	})
}

// knownECDSAKey deterministically derives an ECDSA private key on curve from
// seed.
//
// Derivation steps:
//  1. Read fixed-length big-endian candidate scalars from
//     deterministicRand(seed), clearing bits above bitSize.
//  2. Keep the first candidate curve.NewPrivateKey accepts. This is rejection
//     sampling, so the result depends only on seed.
//  3. Convert the *ecdh.PrivateKey to an *ecdsa.PrivateKey through a PKCS#8
//     round trip; crypto/x509 decodes NIST-curve keys as ECDSA.
func knownECDSAKey(t *testing.T, curve ecdh.Curve, bitSize int, seed string) *ecdsa.PrivateKey {
	t.Helper()
	r := deterministicRand(seed)
	scalar := make([]byte, (bitSize+7)/8)
	topMask := byte(0xff) >> (8*len(scalar) - bitSize) // e.g. 0x01 for P-521
	for range 64 {
		if _, err := io.ReadFull(r, scalar); err != nil {
			t.Fatalf("%s: read scalar: %v", seed, err)
		}
		scalar[0] &= topMask
		key, err := curve.NewPrivateKey(scalar)
		if err != nil {
			continue // zero or >= curve order: draw the next candidate
		}
		der, err := x509.MarshalPKCS8PrivateKey(key)
		if err != nil {
			t.Fatalf("%s: marshal PKCS#8: %v", seed, err)
		}
		parsed, err := x509.ParsePKCS8PrivateKey(der)
		if err != nil {
			t.Fatalf("%s: parse PKCS#8: %v", seed, err)
		}
		priv, ok := parsed.(*ecdsa.PrivateKey)
		if !ok {
			t.Fatalf("%s: PKCS#8 decoded as %T, want *ecdsa.PrivateKey", seed, parsed)
		}
		return priv
	}
	t.Fatalf("%s: no valid scalar after 64 draws", seed)
	return nil
}
// --- Stress tests ---
//
// Each subtest generates 1,000 random keys and signs+verifies in both
// directions per key. This catches edge cases in ASN.1 DER-to-raw signature
// conversion (ECDSA r/s values that are shorter than the field size and
// need left-padding) and any key-dependent encoding quirks.
//
// RSA keygen is inherently slow (~10ms per 2048-bit key); RSA defaults to
// 10 iterations. Use -long to run 100.
func TestStress(t *testing.T) {
t.Run("EdDSA", func(t *testing.T) {
t.Parallel()
for i := range 1000 {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("iter %d: keygen: %v", i, err)
}
pub := priv.Public().(ed25519.PublicKey)
kid := fmt.Sprintf("s%d", i)
stressIteration(t, i,
mustPK(t, priv, kid),
jwt.PublicKey{Pub: pub, KID: kid},
gjwt.SigningMethodEdDSA, priv, pub)
}
})
t.Run("ES256", func(t *testing.T) {
t.Parallel()
for i := range 1000 {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("iter %d: keygen: %v", i, err)
}
kid := fmt.Sprintf("s%d", i)
stressIteration(t, i,
mustPK(t, priv, kid),
jwt.PublicKey{Pub: &priv.PublicKey, KID: kid},
gjwt.SigningMethodES256, priv, &priv.PublicKey)
}
})
t.Run("ES384", func(t *testing.T) {
t.Parallel()
for i := range 1000 {
priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
t.Fatalf("iter %d: keygen: %v", i, err)
}
kid := fmt.Sprintf("s%d", i)
stressIteration(t, i,
mustPK(t, priv, kid),
jwt.PublicKey{Pub: &priv.PublicKey, KID: kid},
gjwt.SigningMethodES384, priv, &priv.PublicKey)
}
})
t.Run("ES512", func(t *testing.T) {
t.Parallel()
for i := range 1000 {
priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
t.Fatalf("iter %d: keygen: %v", i, err)
}
kid := fmt.Sprintf("s%d", i)
stressIteration(t, i,
mustPK(t, priv, kid),
jwt.PublicKey{Pub: &priv.PublicKey, KID: kid},
gjwt.SigningMethodES512, priv, &priv.PublicKey)
}
})
t.Run("RS256", func(t *testing.T) {
t.Parallel()
n := 10
if *longTests {
n = 100
}
for i := range n {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("iter %d: keygen: %v", i, err)
}
kid := fmt.Sprintf("s%d", i)
stressIteration(t, i,
mustPK(t, priv, kid),
jwt.PublicKey{Pub: &priv.PublicKey, KID: kid},
gjwt.SigningMethodRS256, priv, &priv.PublicKey)
}
})
}
// --- JWK private key round-trip ---
//
// Each subtest runs assertPrivateKeyRoundTrip on one key. That helper
// serializes the private key to JWK JSON and parses it back. It then checks
// that the original and recovered keys verify each other's tokens, and that
// golang-jwt accepts the recovered key's token.
func TestJWKPrivateKeyRoundTrip(t *testing.T) {
	t.Run("Ed25519", func(t *testing.T) {
		original, err := jwt.NewPrivateKey()
		if err != nil {
			t.Fatal(err)
		}
		pub, err := original.PublicKey()
		if err != nil {
			t.Fatal(err)
		}
		edPub := pub.Pub.(ed25519.PublicKey)
		assertPrivateKeyRoundTrip(t, original, gjwt.SigningMethodEdDSA, edPub)
	})

	ecCases := []struct {
		name   string
		kid    string
		curve  elliptic.Curve
		method gjwt.SigningMethod
	}{
		{"EC_P256", "ec256-rt", elliptic.P256(), gjwt.SigningMethodES256},
		{"EC_P384", "ec384-rt", elliptic.P384(), gjwt.SigningMethodES384},
		{"EC_P521", "ec521-rt", elliptic.P521(), gjwt.SigningMethodES512},
	}
	for _, tc := range ecCases {
		t.Run(tc.name, func(t *testing.T) {
			priv, err := ecdsa.GenerateKey(tc.curve, rand.Reader)
			if err != nil {
				t.Fatal(err)
			}
			assertPrivateKeyRoundTrip(t, mustPK(t, priv, tc.kid), tc.method, &priv.PublicKey)
		})
	}

	t.Run("RSA", func(t *testing.T) {
		priv, err := rsa.GenerateKey(rand.Reader, 2048)
		if err != nil {
			t.Fatal(err)
		}
		assertPrivateKeyRoundTrip(t, mustPK(t, priv, "rsa-rt"), gjwt.SigningMethodRS256, &priv.PublicKey)
	})
}
// assertPrivateKeyRoundTrip marshals original to JSON, unmarshals it into a
// new jwt.PrivateKey, and checks that:
//   - the KID survives the round trip;
//   - a token signed by the recovered key verifies under original's public key;
//   - a token signed by original verifies under the recovered key's public key;
//   - golang-jwt verifies the recovered key's token with gjwtPub, accepting
//     only gjwtMethod's alg.
func assertPrivateKeyRoundTrip(t *testing.T, original *jwt.PrivateKey, gjwtMethod gjwt.SigningMethod, gjwtPub any) {
	t.Helper()

	data, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("marshal: %v", err)
	}
	var recovered jwt.PrivateKey
	if err := json.Unmarshal(data, &recovered); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if recovered.KID != original.KID {
		t.Errorf("KID: got %q, want %q", recovered.KID, original.KID)
	}

	claims := testClaims("pk-roundtrip")

	// signWith issues a token for claims using key as the only signing key.
	signWith := func(key *jwt.PrivateKey) string {
		t.Helper()
		signer, err := jwt.NewSigner([]*jwt.PrivateKey{key})
		if err != nil {
			t.Fatal(err)
		}
		tokenStr, err := signer.SignToString(claims)
		if err != nil {
			t.Fatal(err)
		}
		return tokenStr
	}

	// verifyWith checks tokenStr against key's public half. The PublicKey and
	// NewVerifier errors were previously ignored; a failed PublicKey() then
	// caused a nil-pointer panic on dereference.
	verifyWith := func(label string, key *jwt.PrivateKey, tokenStr string) {
		t.Helper()
		pub, err := key.PublicKey()
		if err != nil {
			t.Fatalf("%s: PublicKey: %v", label, err)
		}
		verifier, err := jwt.NewVerifier([]jwt.PublicKey{*pub})
		if err != nil {
			t.Fatalf("%s: NewVerifier: %v", label, err)
		}
		if _, err := verifier.VerifyJWT(tokenStr); err != nil {
			t.Errorf("verify with %s: %v", label, err)
		}
	}

	// Sign with recovered key, verify with original pubkey.
	tokenStr := signWith(&recovered)
	verifyWith("original pubkey", original, tokenStr)

	// Sign with original key, verify with recovered pubkey.
	tokenStr2 := signWith(original)
	verifyWith("recovered pubkey", &recovered, tokenStr2)

	// Cross-verify the recovered key's token with golang-jwt. gjwtMethod was
	// previously unused; pin the alg so a wrong header alg is caught.
	_, err = gjwt.ParseWithClaims(tokenStr, &gjwt.RegisteredClaims{},
		func(*gjwt.Token) (any, error) { return gjwtPub, nil },
		gjwt.WithValidMethods([]string{gjwtMethod.Alg()}),
	)
	if err != nil {
		t.Errorf("golang-jwt cross-verify: %v", err)
	}
}
// --- JWK public key round-trip ---
//
// Each subtest runs assertPublicKeyRoundTrip on one key. That helper
// serializes a public key to JWK JSON and parses it back. It then checks that
// the parsed key verifies a token signed with the matching private key, and
// that golang-jwt accepts the same token.
func TestJWKPublicKeyRoundTrip(t *testing.T) {
	// signerFor wraps priv, tagged with kid, as the only key of a new Signer.
	signerFor := func(t *testing.T, priv crypto.Signer, kid string) *jwt.Signer {
		t.Helper()
		signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, priv, kid)})
		if err != nil {
			t.Fatal(err)
		}
		return signer
	}

	t.Run("Ed25519", func(t *testing.T) {
		pub, priv, err := ed25519.GenerateKey(rand.Reader)
		if err != nil {
			t.Fatal(err)
		}
		const kid = "ed-pub-rt"
		assertPublicKeyRoundTrip(t, jwt.PublicKey{Pub: pub, KID: kid}, signerFor(t, priv, kid), pub)
	})

	ecCases := []struct {
		name  string
		kid   string
		curve elliptic.Curve
	}{
		{"EC_P256", "ec256-pub-rt", elliptic.P256()},
		{"EC_P384", "ec384-pub-rt", elliptic.P384()},
		{"EC_P521", "ec521-pub-rt", elliptic.P521()},
	}
	for _, tc := range ecCases {
		t.Run(tc.name, func(t *testing.T) {
			priv, err := ecdsa.GenerateKey(tc.curve, rand.Reader)
			if err != nil {
				t.Fatal(err)
			}
			ours := jwt.PublicKey{Pub: &priv.PublicKey, KID: tc.kid}
			assertPublicKeyRoundTrip(t, ours, signerFor(t, priv, tc.kid), &priv.PublicKey)
		})
	}

	t.Run("RSA", func(t *testing.T) {
		priv, err := rsa.GenerateKey(rand.Reader, 2048)
		if err != nil {
			t.Fatal(err)
		}
		const kid = "rsa-pub-rt"
		ours := jwt.PublicKey{Pub: &priv.PublicKey, KID: kid}
		assertPublicKeyRoundTrip(t, ours, signerFor(t, priv, kid), &priv.PublicKey)
	})
}
// assertPublicKeyRoundTrip marshals origPub to JSON and unmarshals it into a
// new jwt.PublicKey. It checks that the KID survives and that the recovered key
// verifies a token issued by signer. It also checks that golang-jwt verifies
// the same token with gjwtPub.
func assertPublicKeyRoundTrip(t *testing.T, origPub jwt.PublicKey, signer *jwt.Signer, gjwtPub any) {
	t.Helper()

	data, err := json.Marshal(origPub)
	if err != nil {
		t.Fatalf("marshal: %v", err)
	}
	var recovered jwt.PublicKey
	if err := json.Unmarshal(data, &recovered); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if recovered.KID != origPub.KID {
		t.Errorf("KID: got %q, want %q", recovered.KID, origPub.KID)
	}

	// Sign and verify with the round-tripped key.
	tokenStr, err := signer.SignToString(testClaims("pub-roundtrip"))
	if err != nil {
		t.Fatal(err)
	}
	// Previously this error was discarded; a recovered key the Verifier
	// rejects should fail here with a clear message.
	verifier, err := jwt.NewVerifier([]jwt.PublicKey{recovered})
	if err != nil {
		t.Fatalf("NewVerifier with round-tripped pubkey: %v", err)
	}
	if _, err := verifier.VerifyJWT(tokenStr); err != nil {
		t.Errorf("verify with round-tripped pubkey: %v", err)
	}

	// Cross-verify with golang-jwt.
	_, err = gjwt.ParseWithClaims(tokenStr, &gjwt.RegisteredClaims{}, func(*gjwt.Token) (any, error) {
		return gjwtPub, nil
	})
	if err != nil {
		t.Errorf("golang-jwt cross-verify: %v", err)
	}
}
// --- JWKS round-trip ---

// TestJWKSRoundTrip publishes a JWKS document covering every supported key
// type (Ed25519, EC P-256/P-384/P-521, RSA). It checks that the document
// carries no private-key members and parses it back. It then confirms that
// tokens from the signer verify against the parsed key set.
func TestJWKSRoundTrip(t *testing.T) {
	edKey, err := jwt.NewPrivateKey()
	if err != nil {
		t.Fatal(err)
	}
	ec256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	ec384, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	ec521, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatal(err)
	}
	keys := []*jwt.PrivateKey{
		edKey,
		mustPK(t, ec256, "ec256"),
		mustPK(t, ec384, "ec384"),
		mustPK(t, ec521, "ec521"),
		mustPK(t, rsaKey, "rsa"),
	}
	signer, err := jwt.NewSigner(keys)
	if err != nil {
		t.Fatal(err)
	}

	// Serialize the JWKS (public keys only). signer is passed as-is; the
	// original marshalled &signer, a needless pointer-to-pointer.
	jwksData, err := json.Marshal(signer)
	if err != nil {
		t.Fatalf("marshal JWKS: %v", err)
	}

	// A published JWKS must not leak private-key members (RFC 7518 §6.2.2 and
	// §6.3.2, RFC 8037 §2).
	var raw struct {
		Keys []map[string]json.RawMessage `json:"keys"`
	}
	if err := json.Unmarshal(jwksData, &raw); err != nil {
		t.Fatalf("unmarshal raw JWKS: %v", err)
	}
	if len(raw.Keys) != len(keys) {
		t.Fatalf("raw JWKS: expected %d keys, got %d", len(keys), len(raw.Keys))
	}
	for i, k := range raw.Keys {
		for _, member := range []string{"d", "p", "q", "dp", "dq", "qi", "oth"} {
			if _, ok := k[member]; ok {
				t.Errorf("JWKS key %d exposes private member %q", i, member)
			}
		}
	}

	// Parse it back.
	var jwks jwt.WellKnownJWKs
	if err := json.Unmarshal(jwksData, &jwks); err != nil {
		t.Fatalf("unmarshal JWKS: %v", err)
	}
	if len(jwks.Keys) != len(keys) {
		t.Fatalf("expected %d keys, got %d", len(keys), len(jwks.Keys))
	}
	verifier, err := jwt.NewVerifier(jwks.Keys)
	if err != nil {
		t.Fatalf("NewVerifier from parsed JWKS: %v", err)
	}

	// Sign once per key; per the original design the signer rotates through
	// its keys round-robin, so every key should be exercised.
	claims := testClaims("jwks-round-trip")
	for i := range keys {
		tokenStr, err := signer.SignToString(claims)
		if err != nil {
			t.Fatalf("sign[%d]: %v", i, err)
		}
		if _, err := verifier.VerifyJWT(tokenStr); err != nil {
			t.Errorf("verify[%d] after JWKS round-trip: %v", i, err)
		}
	}
}