Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions cmd/gen.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package cmd

import (
"encoding/json"
"os"
"os/signal"
"strings"
"time"

env "github.com/Netflix/go-env"
"github.com/go-errors/errors"
"github.com/go-viper/mapstructure/v2"
"github.com/golang-jwt/jwt/v5"
"github.com/spf13/afero"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -124,17 +127,18 @@ Supported algorithms:
claims config.CustomClaims
expiry time.Time
validFor time.Duration
payload string

genJWTCmd = &cobra.Command{
Use: "bearer-jwt",
Short: "Generate a Bearer Auth JWT for accessing Data API",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
if expiry.IsZero() {
expiry = time.Now().Add(validFor)
custom := jwt.MapClaims{}
if err := parseClaims(custom); err != nil {
return err
}
claims.ExpiresAt = jwt.NewNumericDate(expiry)
return bearerjwt.Run(cmd.Context(), claims, os.Stdout, afero.NewOsFs())
return bearerjwt.Run(cmd.Context(), custom, os.Stdout, afero.NewOsFs())
},
}
)
Expand Down Expand Up @@ -166,12 +170,41 @@ func init() {
genCmd.AddCommand(genSigningKeyCmd)
tokenFlags := genJWTCmd.Flags()
tokenFlags.StringVar(&claims.Role, "role", "", "Postgres role to use.")
cobra.CheckErr(genJWTCmd.MarkFlagRequired("role"))
tokenFlags.StringVar(&claims.Subject, "sub", "", "User ID to impersonate.")
genJWTCmd.Flag("sub").DefValue = "anonymous"
tokenFlags.TimeVar(&expiry, "exp", time.Time{}, []string{time.RFC3339}, "Expiry timestamp for this token.")
tokenFlags.DurationVar(&validFor, "valid-for", time.Minute*30, "Validity duration for this token.")
genJWTCmd.MarkFlagsMutuallyExclusive("exp", "valid-for")
cobra.CheckErr(genJWTCmd.MarkFlagRequired("role"))
tokenFlags.StringVar(&payload, "payload", "{}", "Custom claims in JSON format.")
genCmd.AddCommand(genJWTCmd)
rootCmd.AddCommand(genCmd)
}

func parseClaims(custom jwt.MapClaims) error {
// Initialise default claims
now := time.Now()
if expiry.IsZero() {
expiry = now.Add(validFor)
} else {
now = expiry.Add(-validFor)
}
claims.IssuedAt = jwt.NewNumericDate(now)
claims.ExpiresAt = jwt.NewNumericDate(expiry)
// Set is_anonymous = true for authenticated role without explicit user ID
if strings.EqualFold(claims.Role, "authenticated") && len(claims.Subject) == 0 {
claims.IsAnon = true
}
// Override with custom claims
if dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
TagName: "json",
Result: &custom,
}); err != nil {
return errors.Errorf("failed to init decoder: %w", err)
} else if err := dec.Decode(claims); err != nil {
return errors.Errorf("failed to decode claims: %w", err)
}
if err := json.Unmarshal([]byte(payload), &custom); err != nil {
return errors.Errorf("failed to parse payload: %w", err)
}
return nil
}
73 changes: 56 additions & 17 deletions internal/gen/bearerjwt/bearerjwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,80 @@ package bearerjwt

import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"

"github.com/go-errors/errors"
"github.com/golang-jwt/jwt/v5"
"github.com/spf13/afero"
"github.com/supabase/cli/internal/utils"
"github.com/supabase/cli/internal/utils/flags"
"github.com/supabase/cli/pkg/config"
)

func Run(ctx context.Context, claims config.CustomClaims, w io.Writer, fsys afero.Fs) error {
func Run(ctx context.Context, claims jwt.Claims, w io.Writer, fsys afero.Fs) error {
if err := flags.LoadConfig(fsys); err != nil {
return err
}
// Set is_anonymous = true for authenticated role without explicit user ID
if strings.EqualFold(claims.Role, "authenticated") && len(claims.Subject) == 0 {
claims.IsAnon = true
}
// Use the first signing key that passes validation
for _, k := range utils.Config.Auth.SigningKeys {
fmt.Fprintln(os.Stderr, "Using signing key ID:", k.KeyID.String())
if token, err := config.GenerateAsymmetricJWT(k, claims); err != nil {
fmt.Fprintln(os.Stderr, err)
} else {
fmt.Fprintln(w, token)
return nil
}
key, err := getSigningKey(ctx)
if err != nil {
return err
}
fmt.Fprintln(os.Stderr, "Using legacy JWT secret...")
token, err := claims.NewToken().SignedString([]byte(utils.Config.Auth.JwtSecret.Value))
token, err := config.GenerateAsymmetricJWT(*key, claims)
if err != nil {
return errors.Errorf("failed to generate auth token: %w", err)
return err
}
fmt.Fprintln(w, token)
return nil
}

func getSigningKey(ctx context.Context) (*config.JWK, error) {
console := utils.NewConsole()
if len(utils.Config.Auth.SigningKeysPath) == 0 {
title := "Enter your signing key in JWK format: "
kid, err := console.PromptText(ctx, title)
if err != nil {
return nil, err
}
key := config.JWK{}
if err := json.Unmarshal([]byte(kid), &key); err != nil {
return nil, errors.Errorf("failed to parse JWK: %w", err)
}
return &key, nil
}
// Allow manual kid entry on CI
if !console.IsTTY {
title := "Enter the kid of your signing key (or leave blank to use the first one): "
kid, err := console.PromptText(ctx, title)
if err != nil {
return nil, err
}
for i, k := range utils.Config.Auth.SigningKeys {
if k.KeyID == kid {
return &utils.Config.Auth.SigningKeys[i], nil
}
}
if len(kid) == 0 && len(utils.Config.Auth.SigningKeys) > 0 {
return &utils.Config.Auth.SigningKeys[0], nil
}
return nil, errors.Errorf("signing key not found: %s", kid)
}
// Let user choose from a list of signing keys
items := make([]utils.PromptItem, len(utils.Config.Auth.SigningKeys))
for i, k := range utils.Config.Auth.SigningKeys {
items[i] = utils.PromptItem{
Index: i,
Summary: k.KeyID,
Details: fmt.Sprintf("%s (%s)", k.Algorithm, strings.Join(k.KeyOps, ",")),
}
}
choice, err := utils.PromptChoice(ctx, "Select a signing key:", items)
if err != nil {
return nil, err
}
fmt.Fprintln(os.Stderr, "Selected key ID:", choice.Summary)
return &utils.Config.Auth.SigningKeys[choice.Index], nil
}
147 changes: 121 additions & 26 deletions internal/gen/bearerjwt/bearerjwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,56 @@ import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
_ "embed"
"encoding/json"
"io"
"testing"

"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/supabase/cli/internal/gen/signingkeys"
"github.com/supabase/cli/internal/testing/fstest"
"github.com/supabase/cli/internal/utils"
"github.com/supabase/cli/pkg/config"
)

func TestGenerateToken(t *testing.T) {
// Setup private key - ECDSA
privateKeyECDSA, err := signingkeys.GeneratePrivateKey(config.AlgES256)
require.NoError(t, err)
// Setup public key for validation
publicKeyECDSA := ecdsa.PublicKey{Curve: elliptic.P256()}
publicKeyECDSA.X, err = config.NewBigIntFromBase64(privateKeyECDSA.X)
require.NoError(t, err)
publicKeyECDSA.Y, err = config.NewBigIntFromBase64(privateKeyECDSA.Y)
require.NoError(t, err)

// Setup private key - RSA
privateKeyRSA, err := signingkeys.GeneratePrivateKey(config.AlgRS256)
require.NoError(t, err)
// Setup public key for validation
publicKeyRSA := rsa.PublicKey{}
publicKeyRSA.N, err = config.NewBigIntFromBase64(privateKeyRSA.Modulus)
require.NoError(t, err)
bigE, err := config.NewBigIntFromBase64(privateKeyRSA.Exponent)
require.NoError(t, err)
publicKeyRSA.E = int(bigE.Int64())

t.Run("mints custom JWT", func(t *testing.T) {
claims := config.CustomClaims{
Role: "authenticated",
IsAnon: true,
Role: "authenticated",
}
// Setup private key
privateKey, err := signingkeys.GeneratePrivateKey(config.AlgES256)
require.NoError(t, err)
// Setup public key for validation
publicKey := ecdsa.PublicKey{Curve: elliptic.P256()}
publicKey.X, err = config.NewBigIntFromBase64(privateKey.X)
require.NoError(t, err)
publicKey.Y, err = config.NewBigIntFromBase64(privateKey.Y)
require.NoError(t, err)
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteFile("supabase/config.toml", []byte(`
[auth]
signing_keys_path = "./keys.json"
`), fsys))
testKey, err := json.Marshal([]config.JWK{*privateKey})
testKey, err := json.Marshal([]config.JWK{*privateKeyECDSA})
require.NoError(t, err)
require.NoError(t, utils.WriteFile("supabase/keys.json", testKey, fsys))
// Run test
Expand All @@ -48,13 +63,13 @@ func TestGenerateToken(t *testing.T) {
// Check error
assert.NoError(t, err)
token, err := jwt.NewParser().Parse(buf.String(), func(t *jwt.Token) (any, error) {
return &publicKey, nil
return &publicKeyECDSA, nil
})
assert.NoError(t, err)
assert.True(t, token.Valid)
assert.Equal(t, map[string]any{
"alg": "ES256",
"kid": privateKey.KeyID.String(),
"kid": privateKeyECDSA.KeyID,
"typ": "JWT",
}, token.Header)
assert.Equal(t, jwt.MapClaims{
Expand All @@ -63,36 +78,116 @@ func TestGenerateToken(t *testing.T) {
}, token.Claims)
})

t.Run("mints legacy JWT", func(t *testing.T) {
t.Run("throws error on unsupported kty", func(t *testing.T) {
claims := jwt.MapClaims{}
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteFile("supabase/config.toml", []byte(`
[auth]
signing_keys_path = "./keys.json"
`), fsys))
testKey, err := json.Marshal([]config.JWK{{KeyType: "oct"}})
require.NoError(t, err)
require.NoError(t, utils.WriteFile("supabase/keys.json", testKey, fsys))
// Run test
err = Run(context.Background(), claims, io.Discard, fsys)
// Check error
assert.ErrorContains(t, err, "failed to convert JWK to private key: unsupported key type: oct")
})

t.Run("accepts signing key from stdin", func(t *testing.T) {
utils.Config.Auth.SigningKeysPath = ""
utils.Config.Auth.SigningKeys = nil
claims := config.CustomClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: uuid.New().String(),
},
Role: "authenticated",
Role: "service_role",
}
// Setup in-memory fs
fsys := afero.NewMemMapFs()
testKey, err := json.Marshal(privateKeyRSA)
require.NoError(t, err)
t.Cleanup(fstest.MockStdin(t, string(testKey)))
// Run test
var buf bytes.Buffer
err := Run(context.Background(), claims, &buf, fsys)
err = Run(context.Background(), claims, &buf, fsys)
// Check error
assert.NoError(t, err)
token, err := jwt.NewParser().Parse(buf.String(), func(t *jwt.Token) (any, error) {
return []byte(utils.Config.Auth.JwtSecret.Value), nil
return &publicKeyRSA, nil
})
assert.NoError(t, err)
assert.True(t, token.Valid)
assert.Equal(t, map[string]any{
"alg": "HS256",
"alg": "RS256",
"kid": privateKeyRSA.KeyID,
"typ": "JWT",
}, token.Header)
assert.Equal(t, jwt.MapClaims{
"exp": float64(1983812996),
"iss": "supabase-demo",
"role": "authenticated",
"sub": claims.Subject,
"role": "service_role",
}, token.Claims)
})

t.Run("throws error on invalid key", func(t *testing.T) {
claims := jwt.MapClaims{}
// Setup in-memory fs
fsys := afero.NewMemMapFs()
t.Cleanup(fstest.MockStdin(t, ""))
// Run test
err = Run(context.Background(), claims, io.Discard, fsys)
// Check error
assert.ErrorContains(t, err, "failed to parse JWK: unexpected end of JSON input")
})

t.Run("accepts kid from stdin", func(t *testing.T) {
claims := jwt.MapClaims{
"role": "postgres",
"sb-role": "mgmt-api",
}
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteFile("supabase/config.toml", []byte(`
[auth]
signing_keys_path = "./keys.json"
`), fsys))
testKey, err := json.Marshal([]config.JWK{
*privateKeyECDSA,
*privateKeyRSA,
})
require.NoError(t, err)
require.NoError(t, utils.WriteFile("supabase/keys.json", testKey, fsys))
t.Cleanup(fstest.MockStdin(t, privateKeyRSA.KeyID))
// Run test
var buf bytes.Buffer
err = Run(context.Background(), claims, &buf, fsys)
// Check error
assert.NoError(t, err)
token, err := jwt.NewParser().Parse(buf.String(), func(t *jwt.Token) (any, error) {
return &publicKeyRSA, nil
})
assert.NoError(t, err)
assert.True(t, token.Valid)
assert.Equal(t, map[string]any{
"alg": "RS256",
"kid": privateKeyRSA.KeyID,
"typ": "JWT",
}, token.Header)
assert.Equal(t, jwt.MapClaims{
"role": "postgres",
"sb-role": "mgmt-api",
}, token.Claims)
})

t.Run("throws error on missing key", func(t *testing.T) {
claims := jwt.MapClaims{}
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteFile("supabase/config.toml", []byte(`
[auth]
signing_keys_path = "./keys.json"
`), fsys))
t.Cleanup(fstest.MockStdin(t, "test-key"))
// Run test
err = Run(context.Background(), claims, io.Discard, fsys)
// Check error
assert.ErrorContains(t, err, "signing key not found: test-key")
})
}
Loading