Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
416 changes: 302 additions & 114 deletions pkg/cmd/alpha/tunnel.go

Large diffs are not rendered by default.

194 changes: 194 additions & 0 deletions pkg/cmd/alpha/tunnel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package alpha_test

import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"log/slog"
"net"
"net/netip"
"testing"
"time"

"github.com/apoxy-dev/icx"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"gvisor.dev/gvisor/pkg/tcpip"

"github.com/apoxy-dev/apoxy/pkg/cmd"
"github.com/apoxy-dev/apoxy/pkg/cryptoutils"
"github.com/apoxy-dev/apoxy/pkg/netstack"
"github.com/apoxy-dev/apoxy/pkg/tunnel"
"github.com/apoxy-dev/apoxy/pkg/tunnel/connection"
"github.com/apoxy-dev/apoxy/pkg/tunnel/controllers"
"github.com/apoxy-dev/apoxy/pkg/tunnel/hasher"
"github.com/apoxy-dev/apoxy/pkg/tunnel/router"
)

func TestTunnelRun(t *testing.T) {
if testing.Verbose() {
slog.SetLogLoggerLevel(slog.LevelDebug)
}

var connected bool

// onConnect assigns VNI and overlay address so handleConnect can proceed.
onConnect := func(ctx context.Context, agent string, conn controllers.Connection) error {
// Choose a deterministic VNI for the test.
conn.SetVNI(101)
conn.SetOverlayAddress("10.0.0.2/32")
t.Logf("onConnect called, agent=%s", agent)
if agent == "test-agent" {
connected = true
}
return nil
}

onDisconnect := func(ctx context.Context, agent, id string) error {
t.Logf("onDisconnect called, agent=%s id=%s", agent, id)
return nil
}

r, _, stop := startRelay(t, "letmein", onConnect, onDisconnect)
t.Cleanup(stop)

ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
t.Cleanup(cancel)

cmd := cmd.RootCmd
cmd.SetArgs([]string{
"alpha", "tunnel", "run",
"--agent", "test-agent",
"--name", "test-tunnel",
"--relay-addrs", r.Address(),
"--token", "letmein",
"--insecure-skip-verify",
})
cmd.SilenceUsage = true
err := cmd.ExecuteContext(ctx)
if err != nil && errors.Is(err, context.DeadlineExceeded) {
err = nil // expected on timeout
}
require.NoError(t, err)

require.True(t, connected, "expected to be connected")

// TODO: verify traffic routing through the tunnel
}

func startRelay(t *testing.T, token string, onConnect func(context.Context, string, controllers.Connection) error, onDisconnect func(context.Context, string, string) error) (*tunnel.Relay, tls.Certificate, func()) {
t.Helper()

pc, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)

caCert, serverCert, err := cryptoutils.GenerateSelfSignedTLSCert("localhost")
require.NoError(t, err)

h, err := icx.NewHandler(icx.WithLocalAddr(netstack.ToFullAddress(netip.MustParseAddrPort("127.0.0.1:6081"))),
icx.WithVirtMAC(tcpip.GetRandMacAddr()))
require.NoError(t, err)

idKey := make([]byte, 32)
_, err = rand.Read(idKey)
require.NoError(t, err)

idHasher := hasher.NewHasher(idKey)

rtr := &mockRouter{}

rtr.On("Start", mock.Anything).Return(nil)
rtr.On("Close").Return(nil)

r := tunnel.NewRelay("relay-it", pc, serverCert, h, idHasher, rtr)
r.SetCredentials("test-tunnel", token)
r.SetOnConnect(onConnect)
r.SetOnDisconnect(onDisconnect)

ctx, cancel := context.WithCancel(context.Background())

done := make(chan struct{})
go func() {
if err := r.Start(ctx); err != nil {
t.Errorf("Relay stopped with error: %v", err)
}
close(done)
}()

// Give the server a brief moment to bind and start serving.
time.Sleep(150 * time.Millisecond)

stop := func() {
cancel()
select {
case <-done:
case <-time.After(30 * time.Second):
// if shutdown hangs, tests will fail below anyway
}
_ = pc.Close()
}

return r, caCert, stop
}

type mockRouter struct {
mock.Mock
}

func (m *mockRouter) Start(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}

func (m *mockRouter) AddAddr(addr netip.Prefix, tun connection.Connection) error {
args := m.Called(addr, tun)
return args.Error(0)
}

func (m *mockRouter) ListAddrs() ([]netip.Prefix, error) {
args := m.Called()
var addrs []netip.Prefix
if v := args.Get(0); v != nil {
addrs = v.([]netip.Prefix)
}
return addrs, args.Error(1)
}

func (m *mockRouter) DelAddr(addr netip.Prefix) error {
args := m.Called(addr)
return args.Error(0)
}

func (m *mockRouter) AddRoute(dst netip.Prefix) error {
args := m.Called(dst)
return args.Error(0)
}

func (m *mockRouter) DelRoute(dst netip.Prefix) error {
args := m.Called(dst)
return args.Error(0)
}

func (m *mockRouter) ListRoutes() ([]router.TunnelRoute, error) {
args := m.Called()
var routes []router.TunnelRoute
if v := args.Get(0); v != nil {
routes = v.([]router.TunnelRoute)
}
return routes, args.Error(1)
}

func (m *mockRouter) LocalAddresses() ([]netip.Prefix, error) {
args := m.Called()
var addrs []netip.Prefix
if v := args.Get(0); v != nil {
addrs = v.([]netip.Prefix)
}
return addrs, args.Error(1)
}

func (m *mockRouter) Close() error {
args := m.Called()
return args.Error(0)
}
2 changes: 1 addition & 1 deletion pkg/cmd/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@ If your CLI is already authenticated this will return information about your ses

func init() {
authCmd.PersistentFlags().BoolVar(&checkOnly, "check", false, "only check the authentication status")
rootCmd.AddCommand(authCmd)
RootCmd.AddCommand(authCmd)
}
2 changes: 1 addition & 1 deletion pkg/cmd/dev.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func init() {
BoolVar(&useSubprocess, "use-subprocess", false, "Use subprocess for apiserver and backplane.")
devCmd.PersistentFlags().
StringVar(&clickhouseAddr, "clickhouse-addr", "", "ClickHouse address (host only, port 9000 will be used).")
rootCmd.AddCommand(devCmd)
RootCmd.AddCommand(devCmd)
}

func maybeNamespaced(un *unstructured.Unstructured) string {
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/edgefunction.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,5 +242,5 @@ func init() {
BoolVar(&showEdgeFunctionLabels, "show-labels", false, "Print the edge function's labels.")

alphaEdgeFunctionCmd.AddCommand(getAlphaEdgeFunctionCmd, listAlphaEdgeFunctionCmd, createAlphaEdgeFunctionCmd, deleteAlphaEdgeFunctionCmd)
rootCmd.AddCommand(alphaEdgeFunctionCmd)
RootCmd.AddCommand(alphaEdgeFunctionCmd)
}
2 changes: 1 addition & 1 deletion pkg/cmd/k8s.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,5 +204,5 @@ func init() {
installK8sCmd.Flags().Bool("force", false, "If true, forces value overwrites (See: https://v1-28.docs.kubernetes.io/docs/reference/using-api/server-side-apply/#conflicts)")
k8sCmd.AddCommand(installK8sCmd)

rootCmd.AddCommand(k8sCmd)
RootCmd.AddCommand(k8sCmd)
}
2 changes: 1 addition & 1 deletion pkg/cmd/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,5 +190,5 @@ func init() {
logsCmd.Flags().BoolP("follow", "f", false, "Follow logs in real-time")
logsCmd.Flags().DurationP("since", "", 0, "Show logs since a given duration (e.g. 5m, 1h)")
logsCmd.Flags().StringP("since-time", "", "", "Show logs from a given date (e.g. 2019-01-01T00:00:00Z)")
rootCmd.AddCommand(logsCmd)
RootCmd.AddCommand(logsCmd)
}
2 changes: 1 addition & 1 deletion pkg/cmd/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,5 @@ func init() {
// TODO: add flags for proxy config as raw envoy config

alphaProxyCmd.AddCommand(getAlphaProxyCmd, listAlphaProxyCmd, createAlphaProxyCmd, deleteAlphaProxyCmd)
rootCmd.AddCommand(alphaProxyCmd)
RootCmd.AddCommand(alphaProxyCmd)
}
24 changes: 12 additions & 12 deletions pkg/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import (
"github.com/apoxy-dev/apoxy/pkg/cmd/tunnel"
)

// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
// RootCmd represents the base command when called without any subcommands
var RootCmd = &cobra.Command{
Use: "apoxy",
Short: "Apoxy helps you expose, explore, and evolve your APIs and services.",
Long: `The Apoxy CLI is the quickest and easiest way to create and control Apoxy proxies.
Expand All @@ -43,19 +43,19 @@ Start by creating an account on https://apoxy.dev and logging in with 'apoxy aut
// ExecuteContext executes root command with context.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func ExecuteContext(ctx context.Context) error {
return rootCmd.ExecuteContext(ctx)
return RootCmd.ExecuteContext(ctx)
}

func init() {
rootCmd.PersistentFlags().StringVar(&config.ConfigFile, "config", "", "Config file (default is $HOME/.apoxy/config.yaml).")
rootCmd.PersistentFlags().BoolVar(&config.AlsoLogToStderr, "alsologtostderr", false, "Log to standard error as well as files.")
rootCmd.PersistentFlags().BoolVarP(&config.Verbose, "verbose", "v", false, "Enable verbose output.")
rootCmd.PersistentFlags().BoolVar(&config.LocalMode, "local", false, "Run in local mode.")
rootCmd.PersistentFlags().StringVar(&config.ProjectID, "project", "", "The project ID to use.")
rootCmd.PersistentFlags().BoolVar(&config.PprofEnabled, "pprof", false, "Enable pprof HTTP server on :6060.")
RootCmd.PersistentFlags().StringVar(&config.ConfigFile, "config", "", "Config file (default is $HOME/.apoxy/config.yaml).")
RootCmd.PersistentFlags().BoolVar(&config.AlsoLogToStderr, "alsologtostderr", false, "Log to standard error as well as files.")
RootCmd.PersistentFlags().BoolVarP(&config.Verbose, "verbose", "v", false, "Enable verbose output.")
RootCmd.PersistentFlags().BoolVar(&config.LocalMode, "local", false, "Run in local mode.")
RootCmd.PersistentFlags().StringVar(&config.ProjectID, "project", "", "The project ID to use.")
RootCmd.PersistentFlags().BoolVar(&config.PprofEnabled, "pprof", false, "Enable pprof HTTP server on :6060.")

rootCmd.AddCommand(alpha.Cmd())
rootCmd.AddCommand(tunnel.Cmd())
RootCmd.AddCommand(alpha.Cmd())
RootCmd.AddCommand(tunnel.Cmd())
}

// GenerateDocs generates the docs in the docs folder.
Expand All @@ -67,7 +67,7 @@ func GenerateDocs() {
return fmt.Sprintf("#%s", s)
}
emptyStr := func(s string) string { return "" }
files, err := genMarkdownTreeCustom(rootCmd, "./docs", emptyStr, anchorLinks)
files, err := genMarkdownTreeCustom(RootCmd, "./docs", emptyStr, anchorLinks)
if err != nil {
panic(err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ var upgradeCmd = &cobra.Command{
}

func init() {
rootCmd.AddCommand(upgradeCmd)
RootCmd.AddCommand(upgradeCmd)
}
2 changes: 1 addition & 1 deletion pkg/cmd/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ var versionCmd = &cobra.Command{
}

func init() {
rootCmd.AddCommand(versionCmd)
RootCmd.AddCommand(versionCmd)
}
28 changes: 28 additions & 0 deletions pkg/tunnel/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"path"
"time"

"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
)

Expand All @@ -38,6 +41,9 @@ type ClientOptions struct {
TLSConfig *tls.Config
// Timeout for each request. Defaults to 10s if not set.
Timeout time.Duration
// PacketConn is an optional UDP PacketConn to use for QUIC connections.
// If nil, a new UDP socket will be created for each connection.
PacketConn net.PacketConn
}

func NewClient(opts ClientOptions) (*Client, error) {
Expand Down Expand Up @@ -68,6 +74,28 @@ func NewClient(opts ClientOptions) (*Client, error) {

t := &http3.Transport{
TLSClientConfig: opts.TLSConfig,
QUICConfig: &quic.Config{
Tracer: newConnectionTracer,
},
}

if opts.PacketConn != nil {
quicTransport := &quic.Transport{
Conn: opts.PacketConn,
}
t.Dial = func(ctx context.Context, addr string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
slog.Debug("Dialing QUIC", slog.String("addr", addr), slog.String("udp", udpAddr.String()))
qc, err := quicTransport.DialEarly(ctx, udpAddr, tlsConf, quicConf)
if err != nil {
return nil, err
}
slog.Debug("Dialed QUIC", slog.String("addr", addr), slog.String("udp", udpAddr.String()))
return qc, nil
}
}

hc := &http.Client{
Expand Down
Loading