14 changes: 14 additions & 0 deletions connector.go
@@ -265,6 +265,20 @@ func WithMaxDownloadThreads(numThreads int) ConnOption {
}
}

// WithMaxBytesPerFile sets the maximum number of bytes per file for cloud fetch. Default is 100 MB.
func WithMaxBytesPerFile(maxBytes int64) ConnOption {
return func(c *config.Config) {
c.MaxBytesPerFile = maxBytes
}
}

// WithUseLz4Compression sets whether to request LZ4-compressed results for cloud fetch. Default is false.
func WithUseLz4Compression(useLz4 bool) ConnOption {
return func(c *config.Config) {
c.UseLz4Compression = useLz4
}
}

// Setup of OAuth M2M authentication
func WithClientCredentials(clientID, clientSecret string) ConnOption {
return func(c *config.Config) {
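For reference, a minimal sketch of wiring the two new options into a connector alongside the existing cloud fetch options; the hostname, HTTP path, and token are placeholder values:

```go
package main

import (
	"database/sql"
	"log"

	dbsql "github.com/databricks/databricks-sql-go"
)

func main() {
	// Placeholder values; replace with your workspace settings.
	connector, err := dbsql.NewConnector(
		dbsql.WithServerHostname("example.cloud.databricks.com"),
		dbsql.WithHTTPPath("/sql/1.0/endpoints/12346a5b5b0e123b"),
		dbsql.WithAccessToken("supersecret"),
		dbsql.WithCloudFetch(true),
		dbsql.WithMaxDownloadThreads(15),
		// New in this PR: cap each cloud fetch file at 10 MB
		// and request LZ4-compressed results.
		dbsql.WithMaxBytesPerFile(10*1024*1024),
		dbsql.WithUseLz4Compression(true),
	)
	if err != nil {
		log.Fatal(err)
	}
	db := sql.OpenDB(connector)
	defer db.Close()
}
```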
11 changes: 8 additions & 3 deletions connector_test.go
@@ -5,12 +5,13 @@ import (
"testing"
"time"

"github.com/databricks/databricks-sql-go/auth/pat"
"github.com/databricks/databricks-sql-go/internal/client"
"github.com/databricks/databricks-sql-go/internal/config"
"github.com/hashicorp/go-retryablehttp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/databricks/databricks-sql-go/auth/pat"
"github.com/databricks/databricks-sql-go/internal/client"
"github.com/databricks/databricks-sql-go/internal/config"
)

func TestNewConnector(t *testing.T) {
@@ -40,6 +41,7 @@ func TestNewConnector(t *testing.T) {
WithTransport(roundTripper),
WithCloudFetch(true),
WithMaxDownloadThreads(15),
WithMaxBytesPerFile(10*1024*1024),
WithSkipTLSHostVerify(),
)
expectedCloudFetchConfig := config.CloudFetchConfig{
@@ -48,6 +50,7 @@
MaxFilesInMemory: 10,
MinTimeToExpiry: 0 * time.Second,
CloudFetchSpeedThresholdMbps: 0.1,
MaxBytesPerFile: 10 * 1024 * 1024,
}
expectedUserConfig := config.UserConfig{
Host: host,
@@ -95,6 +98,7 @@
MaxFilesInMemory: 10,
MinTimeToExpiry: 0 * time.Second,
CloudFetchSpeedThresholdMbps: 0.1,
MaxBytesPerFile: 100 * 1024 * 1024,
}
expectedUserConfig := config.UserConfig{
Host: host,
@@ -137,6 +141,7 @@
MaxFilesInMemory: 10,
MinTimeToExpiry: 0 * time.Second,
CloudFetchSpeedThresholdMbps: 0.1,
MaxBytesPerFile: 100 * 1024 * 1024,
}
expectedUserConfig := config.UserConfig{
Host: host,
24 changes: 22 additions & 2 deletions internal/config/config.go
@@ -10,9 +10,10 @@ import (
"strings"
"time"

dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/pkg/errors"

dbsqlerr "github.com/databricks/databricks-sql-go/errors"

"github.com/databricks/databricks-sql-go/auth"
"github.com/databricks/databricks-sql-go/auth/noop"
"github.com/databricks/databricks-sql-go/auth/oauth/m2m"
@@ -174,7 +175,6 @@ func (ucfg UserConfig) WithDefaults() UserConfig {
if ucfg.RetryWaitMax == 0 {
ucfg.RetryWaitMax = 30 * time.Second
}
ucfg.UseLz4Compression = false
ucfg.CloudFetchConfig = CloudFetchConfig{}.WithDefaults()

return ucfg
@@ -272,6 +272,20 @@ func ParseDSN(dsn string) (UserConfig, error) {
ucfg.MaxDownloadThreads = numThreads
}

if maxBytesPerFile, ok, err := params.extractAsInt("maxBytesPerFile"); ok {
if err != nil {
return UserConfig{}, err
}
ucfg.MaxBytesPerFile = int64(maxBytesPerFile)
}

if useLz4Compression, ok, err := params.extractAsBool("useLz4Compression"); ok {
if err != nil {
return UserConfig{}, err
}
ucfg.UseLz4Compression = useLz4Compression
}

// for timezone we do a case insensitive key match.
// We use getNoCase because we want to leave timezone in the params so that it will also
// be used as a session param.
@@ -469,6 +483,7 @@ type CloudFetchConfig struct {
MaxFilesInMemory int
MinTimeToExpiry time.Duration
CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1)
MaxBytesPerFile int64
}

func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig {
@@ -490,6 +505,10 @@ func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig {
cfg.CloudFetchSpeedThresholdMbps = 0.1
}

if cfg.MaxBytesPerFile <= 0 {
cfg.MaxBytesPerFile = 100 * 1024 * 1024 // 100 MB
}

return cfg
}

@@ -500,5 +519,6 @@ func (cfg CloudFetchConfig) DeepCopy() CloudFetchConfig {
MaxFilesInMemory: cfg.MaxFilesInMemory,
MinTimeToExpiry: cfg.MinTimeToExpiry,
CloudFetchSpeedThresholdMbps: cfg.CloudFetchSpeedThresholdMbps,
MaxBytesPerFile: cfg.MaxBytesPerFile,
}
}
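The same knobs are reachable through the DSN: `maxBytesPerFile` is parsed as an integer number of bytes and `useLz4Compression` as a boolean. A short sketch assuming the driver's standard registration under the name `databricks`; the credentials and endpoint are placeholders:

```go
package main

import (
	"database/sql"
	"log"

	_ "github.com/databricks/databricks-sql-go" // registers the "databricks" driver
)

func main() {
	// 10485760 bytes = 10 MB per cloud fetch file.
	dsn := "token:supersecret@example.cloud.databricks.com:443/sql/1.0/endpoints/12346a5b5b0e123b" +
		"?useCloudFetch=true&maxDownloadThreads=15&maxBytesPerFile=10485760&useLz4Compression=true"

	db, err := sql.Open("databricks", dsn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()
}
```

Note that `maxBytesPerFile` goes through the same `extractAsInt` path as `maxDownloadThreads` above, so a non-numeric value fails DSN parsing rather than being silently ignored.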
38 changes: 21 additions & 17 deletions internal/config/config_test.go
@@ -251,14 +251,15 @@ func TestParseConfig(t *testing.T) {
MaxDownloadThreads: 10,
MaxFilesInMemory: 10,
CloudFetchSpeedThresholdMbps: 0.1,
MaxBytesPerFile: 100 * 1024 * 1024,
},
},
wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123b",
wantErr: false,
},
{
name: "with useCloudFetch and maxDownloadThreads",
args: args{dsn: "token:supersecret@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123b?useCloudFetch=true&maxDownloadThreads=15"},
args: args{dsn: "token:supersecret@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123b?useCloudFetch=true&maxDownloadThreads=15&maxBytesPerFile=10485760"},
wantCfg: UserConfig{
Protocol: "https",
Host: "example.cloud.databricks.com",
@@ -276,35 +277,38 @@
MaxDownloadThreads: 15,
MaxFilesInMemory: 10,
CloudFetchSpeedThresholdMbps: 0.1,
MaxBytesPerFile: 10 * 1024 * 1024,
},
},
wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123b",
wantErr: false,
},
{
name: "with everything",
args: args{dsn: "token:supersecret2@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a?catalog=default&schema=system&userAgentEntry=partner-name&timeout=100&maxRows=1000&ANSI_MODE=true&useCloudFetch=true&maxDownloadThreads=15"},
args: args{dsn: "token:supersecret2@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a?catalog=default&schema=system&userAgentEntry=partner-name&timeout=100&maxRows=1000&ANSI_MODE=true&useCloudFetch=true&maxDownloadThreads=15&maxBytesPerFile=10485760&useLz4Compression=true"},
wantCfg: UserConfig{
Protocol: "https",
Host: "example.cloud.databricks.com",
Port: 8000,
AccessToken: "supersecret2",
Authenticator: &pat.PATAuth{AccessToken: "supersecret2"},
HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a",
QueryTimeout: 100 * time.Second,
MaxRows: 1000,
UserAgentEntry: "partner-name",
Catalog: "default",
Schema: "system",
SessionParams: map[string]string{"ANSI_MODE": "true"},
RetryMax: 4,
RetryWaitMin: 1 * time.Second,
RetryWaitMax: 30 * time.Second,
Protocol: "https",
Host: "example.cloud.databricks.com",
Port: 8000,
AccessToken: "supersecret2",
Authenticator: &pat.PATAuth{AccessToken: "supersecret2"},
HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a",
QueryTimeout: 100 * time.Second,
MaxRows: 1000,
UserAgentEntry: "partner-name",
Catalog: "default",
Schema: "system",
SessionParams: map[string]string{"ANSI_MODE": "true"},
RetryMax: 4,
RetryWaitMin: 1 * time.Second,
RetryWaitMax: 30 * time.Second,
UseLz4Compression: true,
CloudFetchConfig: CloudFetchConfig{
UseCloudFetch: true,
MaxDownloadThreads: 15,
MaxFilesInMemory: 10,
CloudFetchSpeedThresholdMbps: 0.1,
MaxBytesPerFile: 10 * 1024 * 1024,
},
},
wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a",
113 changes: 79 additions & 34 deletions internal/rows/arrowbased/batchloader.go
@@ -5,17 +5,19 @@ import (
"context"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"

"github.com/databricks/databricks-sql-go/internal/config"
"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
"github.com/pierrec/lz4/v4"
"github.com/pkg/errors"

"net/http"
"github.com/databricks/databricks-sql-go/internal/config"
"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"

"github.com/apache/arrow/go/v12/arrow/ipc"

dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
Expand All @@ -40,11 +42,13 @@ func NewCloudIPCStreamIterator(
startRowOffset: startRowOffset,
pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](),
downloadTasks: NewQueue[cloudFetchDownloadTask](),
results: make(chan cloudFetchDownloadTaskResult, cfg.MaxDownloadThreads*2),
}

for _, link := range files {
bi.pendingLinks.Enqueue(link)
}
go bi.startDownloads()

return bi, nil
}
@@ -140,49 +144,90 @@ type cloudIPCStreamIterator struct {
startRowOffset int64
pendingLinks Queue[cli_service.TSparkArrowResultLink]
downloadTasks Queue[cloudFetchDownloadTask]
results chan cloudFetchDownloadTaskResult
wg sync.WaitGroup
}

var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil)

func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
for (bi.downloadTasks.Len() < bi.cfg.MaxDownloadThreads) && (bi.pendingLinks.Len() > 0) {
link := bi.pendingLinks.Dequeue()
logger.Debug().Msgf(
"CloudFetch: schedule link at offset %d row count %d",
link.StartRowOffset,
link.RowCount,
)
func (bi *cloudIPCStreamIterator) startDownloads() {
defer func() {
bi.wg.Wait()
close(bi.results)
}()

// Keep running while there are links to schedule or downloads in flight
for bi.pendingLinks.Len() > 0 || bi.downloadTasks.Len() > 0 {
// Fill up to MaxDownloadThreads
for bi.pendingLinks.Len() > 0 && bi.downloadTasks.Len() < bi.cfg.MaxDownloadThreads {
link := bi.pendingLinks.Dequeue()

cancelCtx, cancelFn := context.WithCancel(bi.ctx)
task := &cloudFetchDownloadTask{
ctx: cancelCtx,
cancel: cancelFn,
useLz4Compression: bi.cfg.UseLz4Compression,
link: link,
resultChan: make(chan cloudFetchDownloadTaskResult),
minTimeToExpiry: bi.cfg.MinTimeToExpiry,
}
bi.downloadTasks.Enqueue(task)

bi.wg.Add(1)
go func(t *cloudFetchDownloadTask) {
defer bi.wg.Done()
defer func() {
if r := recover(); r != nil {
// Don’t block indefinitely on shutdown
select {
case bi.results <- cloudFetchDownloadTaskResult{nil, fmt.Errorf("panic: %v", r)}:
case <-bi.ctx.Done():
}
}
}()

// Do the real work inside the goroutine
t.Run() // start the download
res, err := t.GetResult() // block until it completes or fails

// Publish result unless shutting down
select {
case bi.results <- cloudFetchDownloadTaskResult{res, err}:
case <-bi.ctx.Done():
}

// Free a slot in the active-task queue so the scheduler can start another download
_ = bi.downloadTasks.Dequeue()
}(task)
}

cancelCtx, cancelFn := context.WithCancel(bi.ctx)
task := &cloudFetchDownloadTask{
ctx: cancelCtx,
cancel: cancelFn,
useLz4Compression: bi.cfg.UseLz4Compression,
link: link,
resultChan: make(chan cloudFetchDownloadTaskResult),
minTimeToExpiry: bi.cfg.MinTimeToExpiry,
speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps,
// Check for cancellation; otherwise yield briefly before refilling
select {
case <-bi.ctx.Done():
// Cancel all in-flight tasks and exit
for bi.downloadTasks.Len() > 0 {
t := bi.downloadTasks.Dequeue()
t.cancel()
}
return
default:
// Yield briefly; avoids tight spin while allowing quick refills
time.Sleep(10 * time.Millisecond)
}
task.Run()
bi.downloadTasks.Enqueue(task)
}
}

task := bi.downloadTasks.Dequeue()
if task == nil {
func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
result, ok := <-bi.results
if !ok {
// Channel closed, no more results
return nil, io.EOF
}

data, err := task.GetResult()

// once a task has errored out, cancel the remaining ones
if err != nil {
if result.err != nil {
bi.Close()
return nil, err
return nil, result.err
}

// explicitly call cancel function on successfully completed task to avoid context leak
task.cancel()
return data, nil
return result.data, nil
}

func (bi *cloudIPCStreamIterator) HasNext() bool {
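The rewritten iterator moves scheduling out of `Next()` into a producer goroutine: it keeps at most `MaxDownloadThreads` downloads in flight and publishes completions to a buffered `results` channel that `Next()` simply receives from. A self-contained sketch of that producer/consumer shape, under simplified assumptions (plain strings instead of the driver's task types, a semaphore channel instead of the task queue):

```go
package main

import (
	"context"
	"fmt"
)

type result struct {
	data string
	err  error
}

// startDownloads keeps at most maxThreads fetches in flight and closes
// the results channel once every scheduled worker has finished.
func startDownloads(ctx context.Context, links []string, maxThreads int) <-chan result {
	results := make(chan result, maxThreads*2) // same 2x buffering as the PR
	sem := make(chan struct{}, maxThreads)     // bounds concurrency

	go func() {
		defer close(results)
	schedule:
		for _, link := range links {
			select {
			case sem <- struct{}{}: // acquire a download slot
			case <-ctx.Done():
				break schedule
			}
			go func(link string) {
				defer func() { <-sem }() // release the slot when done
				// Placeholder for the real HTTP fetch (and LZ4 decode).
				select {
				case results <- result{data: "contents of " + link}:
				case <-ctx.Done():
				}
			}(link)
		}
		// Fill the semaphore to wait for in-flight workers before closing.
		for i := 0; i < cap(sem); i++ {
			sem <- struct{}{}
		}
	}()
	return results
}

func main() {
	for r := range startDownloads(context.Background(), []string{"a", "b", "c"}, 2) {
		fmt.Println(r.data, r.err)
	}
}
```

The PR's version additionally tracks in-flight tasks in `downloadTasks` so that cancellation (and `Close()` after a failed task) can cancel whatever is still running.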
4 changes: 2 additions & 2 deletions internal/rows/arrowbased/batchloader_test.go
@@ -9,10 +9,11 @@ import (
"testing"
"time"

"github.com/pkg/errors"

dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/config"
"github.com/pkg/errors"

"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
@@ -222,7 +223,6 @@ func TestCloudFetchIterator(t *testing.T) {

assert.True(t, bi.HasNext())
assert.Equal(t, cbi.pendingLinks.Len(), len(links))
assert.Equal(t, cbi.downloadTasks.Len(), 0)

// set handler for the first link, which returns some data
handler = func(w http.ResponseWriter, r *http.Request) {