Skip to content

Commit 27590f4

Browse files
authored
feat: improve token generation commands (#4226)
2 parents ef9b8a4 + 038cb98 commit 27590f4

File tree

7 files changed

+269
-104
lines changed

7 files changed

+269
-104
lines changed

cmd/gen.go

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package cmd
22

33
import (
4+
"encoding/json"
45
"os"
56
"os/signal"
7+
"strings"
68
"time"
79

810
env "github.com/Netflix/go-env"
911
"github.com/go-errors/errors"
12+
"github.com/go-viper/mapstructure/v2"
1013
"github.com/golang-jwt/jwt/v5"
1114
"github.com/spf13/afero"
1215
"github.com/spf13/cobra"
@@ -124,17 +127,18 @@ Supported algorithms:
124127
claims config.CustomClaims
125128
expiry time.Time
126129
validFor time.Duration
130+
payload string
127131

128132
genJWTCmd = &cobra.Command{
129133
Use: "bearer-jwt",
130134
Short: "Generate a Bearer Auth JWT for accessing Data API",
131135
Args: cobra.NoArgs,
132136
RunE: func(cmd *cobra.Command, args []string) error {
133-
if expiry.IsZero() {
134-
expiry = time.Now().Add(validFor)
137+
custom := jwt.MapClaims{}
138+
if err := parseClaims(custom); err != nil {
139+
return err
135140
}
136-
claims.ExpiresAt = jwt.NewNumericDate(expiry)
137-
return bearerjwt.Run(cmd.Context(), claims, os.Stdout, afero.NewOsFs())
141+
return bearerjwt.Run(cmd.Context(), custom, os.Stdout, afero.NewOsFs())
138142
},
139143
}
140144
)
@@ -166,12 +170,41 @@ func init() {
166170
genCmd.AddCommand(genSigningKeyCmd)
167171
tokenFlags := genJWTCmd.Flags()
168172
tokenFlags.StringVar(&claims.Role, "role", "", "Postgres role to use.")
173+
cobra.CheckErr(genJWTCmd.MarkFlagRequired("role"))
169174
tokenFlags.StringVar(&claims.Subject, "sub", "", "User ID to impersonate.")
170175
genJWTCmd.Flag("sub").DefValue = "anonymous"
171176
tokenFlags.TimeVar(&expiry, "exp", time.Time{}, []string{time.RFC3339}, "Expiry timestamp for this token.")
172177
tokenFlags.DurationVar(&validFor, "valid-for", time.Minute*30, "Validity duration for this token.")
173-
genJWTCmd.MarkFlagsMutuallyExclusive("exp", "valid-for")
174-
cobra.CheckErr(genJWTCmd.MarkFlagRequired("role"))
178+
tokenFlags.StringVar(&payload, "payload", "{}", "Custom claims in JSON format.")
175179
genCmd.AddCommand(genJWTCmd)
176180
rootCmd.AddCommand(genCmd)
177181
}
182+
183+
func parseClaims(custom jwt.MapClaims) error {
184+
// Initialise default claims
185+
now := time.Now()
186+
if expiry.IsZero() {
187+
expiry = now.Add(validFor)
188+
} else {
189+
now = expiry.Add(-validFor)
190+
}
191+
claims.IssuedAt = jwt.NewNumericDate(now)
192+
claims.ExpiresAt = jwt.NewNumericDate(expiry)
193+
// Set is_anonymous = true for authenticated role without explicit user ID
194+
if strings.EqualFold(claims.Role, "authenticated") && len(claims.Subject) == 0 {
195+
claims.IsAnon = true
196+
}
197+
// Override with custom claims
198+
if dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
199+
TagName: "json",
200+
Result: &custom,
201+
}); err != nil {
202+
return errors.Errorf("failed to init decoder: %w", err)
203+
} else if err := dec.Decode(claims); err != nil {
204+
return errors.Errorf("failed to decode claims: %w", err)
205+
}
206+
if err := json.Unmarshal([]byte(payload), &custom); err != nil {
207+
return errors.Errorf("failed to parse payload: %w", err)
208+
}
209+
return nil
210+
}

internal/gen/bearerjwt/bearerjwt.go

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,80 @@ package bearerjwt
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"io"
78
"os"
89
"strings"
910

1011
"github.com/go-errors/errors"
12+
"github.com/golang-jwt/jwt/v5"
1113
"github.com/spf13/afero"
1214
"github.com/supabase/cli/internal/utils"
1315
"github.com/supabase/cli/internal/utils/flags"
1416
"github.com/supabase/cli/pkg/config"
1517
)
1618

17-
func Run(ctx context.Context, claims config.CustomClaims, w io.Writer, fsys afero.Fs) error {
19+
func Run(ctx context.Context, claims jwt.Claims, w io.Writer, fsys afero.Fs) error {
1820
if err := flags.LoadConfig(fsys); err != nil {
1921
return err
2022
}
21-
// Set is_anonymous = true for authenticated role without explicit user ID
22-
if strings.EqualFold(claims.Role, "authenticated") && len(claims.Subject) == 0 {
23-
claims.IsAnon = true
24-
}
25-
// Use the first signing key that passes validation
26-
for _, k := range utils.Config.Auth.SigningKeys {
27-
fmt.Fprintln(os.Stderr, "Using signing key ID:", k.KeyID.String())
28-
if token, err := config.GenerateAsymmetricJWT(k, claims); err != nil {
29-
fmt.Fprintln(os.Stderr, err)
30-
} else {
31-
fmt.Fprintln(w, token)
32-
return nil
33-
}
23+
key, err := getSigningKey(ctx)
24+
if err != nil {
25+
return err
3426
}
35-
fmt.Fprintln(os.Stderr, "Using legacy JWT secret...")
36-
token, err := claims.NewToken().SignedString([]byte(utils.Config.Auth.JwtSecret.Value))
27+
token, err := config.GenerateAsymmetricJWT(*key, claims)
3728
if err != nil {
38-
return errors.Errorf("failed to generate auth token: %w", err)
29+
return err
3930
}
4031
fmt.Fprintln(w, token)
4132
return nil
4233
}
34+
35+
func getSigningKey(ctx context.Context) (*config.JWK, error) {
36+
console := utils.NewConsole()
37+
if len(utils.Config.Auth.SigningKeysPath) == 0 {
38+
title := "Enter your signing key in JWK format: "
39+
kid, err := console.PromptText(ctx, title)
40+
if err != nil {
41+
return nil, err
42+
}
43+
key := config.JWK{}
44+
if err := json.Unmarshal([]byte(kid), &key); err != nil {
45+
return nil, errors.Errorf("failed to parse JWK: %w", err)
46+
}
47+
return &key, nil
48+
}
49+
// Allow manual kid entry on CI
50+
if !console.IsTTY {
51+
title := "Enter the kid of your signing key (or leave blank to use the first one): "
52+
kid, err := console.PromptText(ctx, title)
53+
if err != nil {
54+
return nil, err
55+
}
56+
for i, k := range utils.Config.Auth.SigningKeys {
57+
if k.KeyID == kid {
58+
return &utils.Config.Auth.SigningKeys[i], nil
59+
}
60+
}
61+
if len(kid) == 0 && len(utils.Config.Auth.SigningKeys) > 0 {
62+
return &utils.Config.Auth.SigningKeys[0], nil
63+
}
64+
return nil, errors.Errorf("signing key not found: %s", kid)
65+
}
66+
// Let user choose from a list of signing keys
67+
items := make([]utils.PromptItem, len(utils.Config.Auth.SigningKeys))
68+
for i, k := range utils.Config.Auth.SigningKeys {
69+
items[i] = utils.PromptItem{
70+
Index: i,
71+
Summary: k.KeyID,
72+
Details: fmt.Sprintf("%s (%s)", k.Algorithm, strings.Join(k.KeyOps, ",")),
73+
}
74+
}
75+
choice, err := utils.PromptChoice(ctx, "Select a signing key:", items)
76+
if err != nil {
77+
return nil, err
78+
}
79+
fmt.Fprintln(os.Stderr, "Selected key ID:", choice.Summary)
80+
return &utils.Config.Auth.SigningKeys[choice.Index], nil
81+
}

internal/gen/bearerjwt/bearerjwt_test.go

Lines changed: 121 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,41 +5,56 @@ import (
55
"context"
66
"crypto/ecdsa"
77
"crypto/elliptic"
8+
"crypto/rsa"
89
_ "embed"
910
"encoding/json"
11+
"io"
1012
"testing"
1113

1214
"github.com/golang-jwt/jwt/v5"
13-
"github.com/google/uuid"
1415
"github.com/spf13/afero"
1516
"github.com/stretchr/testify/assert"
1617
"github.com/stretchr/testify/require"
1718
"github.com/supabase/cli/internal/gen/signingkeys"
19+
"github.com/supabase/cli/internal/testing/fstest"
1820
"github.com/supabase/cli/internal/utils"
1921
"github.com/supabase/cli/pkg/config"
2022
)
2123

2224
func TestGenerateToken(t *testing.T) {
25+
// Setup private key - ECDSA
26+
privateKeyECDSA, err := signingkeys.GeneratePrivateKey(config.AlgES256)
27+
require.NoError(t, err)
28+
// Setup public key for validation
29+
publicKeyECDSA := ecdsa.PublicKey{Curve: elliptic.P256()}
30+
publicKeyECDSA.X, err = config.NewBigIntFromBase64(privateKeyECDSA.X)
31+
require.NoError(t, err)
32+
publicKeyECDSA.Y, err = config.NewBigIntFromBase64(privateKeyECDSA.Y)
33+
require.NoError(t, err)
34+
35+
// Setup private key - RSA
36+
privateKeyRSA, err := signingkeys.GeneratePrivateKey(config.AlgRS256)
37+
require.NoError(t, err)
38+
// Setup public key for validation
39+
publicKeyRSA := rsa.PublicKey{}
40+
publicKeyRSA.N, err = config.NewBigIntFromBase64(privateKeyRSA.Modulus)
41+
require.NoError(t, err)
42+
bigE, err := config.NewBigIntFromBase64(privateKeyRSA.Exponent)
43+
require.NoError(t, err)
44+
publicKeyRSA.E = int(bigE.Int64())
45+
2346
t.Run("mints custom JWT", func(t *testing.T) {
2447
claims := config.CustomClaims{
25-
Role: "authenticated",
48+
IsAnon: true,
49+
Role: "authenticated",
2650
}
27-
// Setup private key
28-
privateKey, err := signingkeys.GeneratePrivateKey(config.AlgES256)
29-
require.NoError(t, err)
30-
// Setup public key for validation
31-
publicKey := ecdsa.PublicKey{Curve: elliptic.P256()}
32-
publicKey.X, err = config.NewBigIntFromBase64(privateKey.X)
33-
require.NoError(t, err)
34-
publicKey.Y, err = config.NewBigIntFromBase64(privateKey.Y)
35-
require.NoError(t, err)
3651
// Setup in-memory fs
3752
fsys := afero.NewMemMapFs()
3853
require.NoError(t, utils.WriteFile("supabase/config.toml", []byte(`
3954
[auth]
4055
signing_keys_path = "./keys.json"
4156
`), fsys))
42-
testKey, err := json.Marshal([]config.JWK{*privateKey})
57+
testKey, err := json.Marshal([]config.JWK{*privateKeyECDSA})
4358
require.NoError(t, err)
4459
require.NoError(t, utils.WriteFile("supabase/keys.json", testKey, fsys))
4560
// Run test
@@ -48,13 +63,13 @@ func TestGenerateToken(t *testing.T) {
4863
// Check error
4964
assert.NoError(t, err)
5065
token, err := jwt.NewParser().Parse(buf.String(), func(t *jwt.Token) (any, error) {
51-
return &publicKey, nil
66+
return &publicKeyECDSA, nil
5267
})
5368
assert.NoError(t, err)
5469
assert.True(t, token.Valid)
5570
assert.Equal(t, map[string]any{
5671
"alg": "ES256",
57-
"kid": privateKey.KeyID.String(),
72+
"kid": privateKeyECDSA.KeyID,
5873
"typ": "JWT",
5974
}, token.Header)
6075
assert.Equal(t, jwt.MapClaims{
@@ -63,36 +78,116 @@ func TestGenerateToken(t *testing.T) {
6378
}, token.Claims)
6479
})
6580

66-
t.Run("mints legacy JWT", func(t *testing.T) {
81+
t.Run("throws error on unsupported kty", func(t *testing.T) {
82+
claims := jwt.MapClaims{}
83+
// Setup in-memory fs
84+
fsys := afero.NewMemMapFs()
85+
require.NoError(t, utils.WriteFile("supabase/config.toml", []byte(`
86+
[auth]
87+
signing_keys_path = "./keys.json"
88+
`), fsys))
89+
testKey, err := json.Marshal([]config.JWK{{KeyType: "oct"}})
90+
require.NoError(t, err)
91+
require.NoError(t, utils.WriteFile("supabase/keys.json", testKey, fsys))
92+
// Run test
93+
err = Run(context.Background(), claims, io.Discard, fsys)
94+
// Check error
95+
assert.ErrorContains(t, err, "failed to convert JWK to private key: unsupported key type: oct")
96+
})
97+
98+
t.Run("accepts signing key from stdin", func(t *testing.T) {
6799
utils.Config.Auth.SigningKeysPath = ""
68100
utils.Config.Auth.SigningKeys = nil
69101
claims := config.CustomClaims{
70-
RegisteredClaims: jwt.RegisteredClaims{
71-
Subject: uuid.New().String(),
72-
},
73-
Role: "authenticated",
102+
Role: "service_role",
74103
}
75104
// Setup in-memory fs
76105
fsys := afero.NewMemMapFs()
106+
testKey, err := json.Marshal(privateKeyRSA)
107+
require.NoError(t, err)
108+
t.Cleanup(fstest.MockStdin(t, string(testKey)))
77109
// Run test
78110
var buf bytes.Buffer
79-
err := Run(context.Background(), claims, &buf, fsys)
111+
err = Run(context.Background(), claims, &buf, fsys)
80112
// Check error
81113
assert.NoError(t, err)
82114
token, err := jwt.NewParser().Parse(buf.String(), func(t *jwt.Token) (any, error) {
83-
return []byte(utils.Config.Auth.JwtSecret.Value), nil
115+
return &publicKeyRSA, nil
84116
})
85117
assert.NoError(t, err)
86118
assert.True(t, token.Valid)
87119
assert.Equal(t, map[string]any{
88-
"alg": "HS256",
120+
"alg": "RS256",
121+
"kid": privateKeyRSA.KeyID,
89122
"typ": "JWT",
90123
}, token.Header)
91124
assert.Equal(t, jwt.MapClaims{
92-
"exp": float64(1983812996),
93-
"iss": "supabase-demo",
94-
"role": "authenticated",
95-
"sub": claims.Subject,
125+
"role": "service_role",
96126
}, token.Claims)
97127
})
128+
129+
t.Run("throws error on invalid key", func(t *testing.T) {
130+
claims := jwt.MapClaims{}
131+
// Setup in-memory fs
132+
fsys := afero.NewMemMapFs()
133+
t.Cleanup(fstest.MockStdin(t, ""))
134+
// Run test
135+
err = Run(context.Background(), claims, io.Discard, fsys)
136+
// Check error
137+
assert.ErrorContains(t, err, "failed to parse JWK: unexpected end of JSON input")
138+
})
139+
140+
t.Run("accepts kid from stdin", func(t *testing.T) {
141+
claims := jwt.MapClaims{
142+
"role": "postgres",
143+
"sb-role": "mgmt-api",
144+
}
145+
// Setup in-memory fs
146+
fsys := afero.NewMemMapFs()
147+
require.NoError(t, utils.WriteFile("supabase/config.toml", []byte(`
148+
[auth]
149+
signing_keys_path = "./keys.json"
150+
`), fsys))
151+
testKey, err := json.Marshal([]config.JWK{
152+
*privateKeyECDSA,
153+
*privateKeyRSA,
154+
})
155+
require.NoError(t, err)
156+
require.NoError(t, utils.WriteFile("supabase/keys.json", testKey, fsys))
157+
t.Cleanup(fstest.MockStdin(t, privateKeyRSA.KeyID))
158+
// Run test
159+
var buf bytes.Buffer
160+
err = Run(context.Background(), claims, &buf, fsys)
161+
// Check error
162+
assert.NoError(t, err)
163+
token, err := jwt.NewParser().Parse(buf.String(), func(t *jwt.Token) (any, error) {
164+
return &publicKeyRSA, nil
165+
})
166+
assert.NoError(t, err)
167+
assert.True(t, token.Valid)
168+
assert.Equal(t, map[string]any{
169+
"alg": "RS256",
170+
"kid": privateKeyRSA.KeyID,
171+
"typ": "JWT",
172+
}, token.Header)
173+
assert.Equal(t, jwt.MapClaims{
174+
"role": "postgres",
175+
"sb-role": "mgmt-api",
176+
}, token.Claims)
177+
})
178+
179+
t.Run("throws error on missing key", func(t *testing.T) {
180+
claims := jwt.MapClaims{}
181+
// Setup in-memory fs
182+
fsys := afero.NewMemMapFs()
183+
require.NoError(t, utils.WriteFile("supabase/config.toml", []byte(`
184+
[auth]
185+
signing_keys_path = "./keys.json"
186+
`), fsys))
187+
t.Cleanup(fstest.MockStdin(t, "test-key"))
188+
// Run test
189+
err = Run(context.Background(), claims, io.Discard, fsys)
190+
// Check error
191+
assert.ErrorContains(t, err, "signing key not found: test-key")
192+
})
98193
}

0 commit comments

Comments
 (0)