Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
"github.com/databricks/databricks-sql-go/internal/rows"
"github.com/databricks/databricks-sql-go/internal/sentinel"
"github.com/databricks/databricks-sql-go/internal/thrift_protocol"
"github.com/databricks/databricks-sql-go/logger"
"github.com/pkg/errors"
)
Expand Down Expand Up @@ -285,14 +286,30 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
Statement: query,
RunAsync: true,
QueryTimeout: int64(c.cfg.QueryTimeout / time.Second),
GetDirectResults: &cli_service.TSparkGetDirectResults{
}

// Check protocol version for feature support
serverProtocolVersion := c.session.ServerProtocolVersion

// Add direct results if supported
if thrift_protocol.SupportsDirectResults(serverProtocolVersion) {
req.GetDirectResults = &cli_service.TSparkGetDirectResults{
MaxRows: int64(c.cfg.MaxRows),
},
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
Parameters: parameters,
}
}

if c.cfg.UseArrowBatches {
// Add LZ4 compression if supported and enabled
if thrift_protocol.SupportsLz4Compression(serverProtocolVersion) && c.cfg.UseLz4Compression {
req.CanDecompressLZ4Result_ = &c.cfg.UseLz4Compression
}

// Add cloud fetch if supported and enabled
if thrift_protocol.SupportsCloudFetch(serverProtocolVersion) && c.cfg.UseCloudFetch {
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
}

// Add Arrow support if supported and enabled
if thrift_protocol.SupportsArrow(serverProtocolVersion) && c.cfg.UseArrowBatches {
req.CanReadArrowResult_ = &c.cfg.UseArrowBatches
req.UseArrowNativeTypes = &cli_service.TSparkArrowTypes{
DecimalAsArrow: &c.cfg.UseArrowNativeDecimal,
Expand All @@ -302,8 +319,9 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
}
}

if c.cfg.UseCloudFetch {
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
// Add parameters if supported and provided
if thrift_protocol.SupportsParameterizedQueries(serverProtocolVersion) && len(parameters) > 0 {
req.Parameters = parameters
}

resp, err := c.client.ExecuteStatement(ctx, &req)
Expand Down
162 changes: 162 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/client"
"github.com/databricks/databricks-sql-go/internal/config"
"github.com/databricks/databricks-sql-go/internal/thrift_protocol"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -331,6 +332,167 @@ func TestConn_executeStatement(t *testing.T) {

}

// TestConn_executeStatement_ProtocolFeatures checks, for every listed Spark
// protocol version, that executeStatement only populates the optional request
// fields (direct results, LZ4, cloud fetch, Arrow, parameters) when the
// session's server protocol version supports the feature and the matching
// config option is enabled.
func TestConn_executeStatement_ProtocolFeatures(t *testing.T) {
	t.Parallel()

	protocols := []cli_service.TProtocolVersion{
		cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V1,
		cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V2,
		cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V3,
		cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V4,
		cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V5,
		cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6,
		cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V7,
		cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8,
	}

	testCases := []struct {
		cfg                          *config.Config
		supportsDirectResults        func(version cli_service.TProtocolVersion) bool
		supportsLz4Compression       func(version cli_service.TProtocolVersion) bool
		supportsCloudFetch           func(version cli_service.TProtocolVersion) bool
		supportsArrow                func(version cli_service.TProtocolVersion) bool
		supportsParameterizedQueries func(version cli_service.TProtocolVersion) bool
		hasParameters                bool
	}{
		{
			// Every optional feature enabled in the config, with one query parameter.
			cfg: func() *config.Config {
				cfg := config.WithDefaults()
				cfg.UseLz4Compression = true
				cfg.UseCloudFetch = true
				cfg.UseArrowBatches = true
				cfg.UseArrowNativeDecimal = true
				cfg.UseArrowNativeTimestamp = true
				cfg.UseArrowNativeComplexTypes = true
				cfg.UseArrowNativeIntervalTypes = true
				return cfg
			}(),
			supportsDirectResults:        thrift_protocol.SupportsDirectResults,
			supportsLz4Compression:       thrift_protocol.SupportsLz4Compression,
			supportsCloudFetch:           thrift_protocol.SupportsCloudFetch,
			supportsArrow:                thrift_protocol.SupportsArrow,
			supportsParameterizedQueries: thrift_protocol.SupportsParameterizedQueries,
			hasParameters:                true,
		},
		{
			// Every optional feature disabled in the config, no query parameters.
			cfg: func() *config.Config {
				cfg := config.WithDefaults()
				cfg.UseLz4Compression = false
				cfg.UseCloudFetch = false
				cfg.UseArrowBatches = false
				return cfg
			}(),
			supportsDirectResults:        thrift_protocol.SupportsDirectResults,
			supportsLz4Compression:       thrift_protocol.SupportsLz4Compression,
			supportsCloudFetch:           thrift_protocol.SupportsCloudFetch,
			supportsArrow:                thrift_protocol.SupportsArrow,
			supportsParameterizedQueries: thrift_protocol.SupportsParameterizedQueries,
			hasParameters:                false,
		},
	}

	for _, tc := range testCases {
		for _, version := range protocols {
			t.Run(fmt.Sprintf("protocol_v%d_withParams_%v", version, tc.hasParameters), func(t *testing.T) {
				// The stub records the request so the assertions below can inspect it.
				var capturedReq *cli_service.TExecuteStatementReq
				executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
					capturedReq = req
					executeStatementResp := &cli_service.TExecuteStatementResp{
						Status: &cli_service.TStatus{
							StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
						},
						OperationHandle: &cli_service.TOperationHandle{
							OperationId: &cli_service.THandleIdentifier{
								GUID:   []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
								Secret: []byte("secret"),
							},
						},
						DirectResults: &cli_service.TSparkDirectResults{
							OperationStatus: &cli_service.TGetOperationStatusResp{
								Status: &cli_service.TStatus{
									StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
								},
								OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
							},
						},
					}
					return executeStatementResp, nil
				}

				session := getTestSession()
				session.ServerProtocolVersion = version

				testClient := &client.TestClient{
					FnExecuteStatement: executeStatement,
				}

				testConn := &conn{
					session: session,
					client:  testClient,
					cfg:     tc.cfg,
				}

				var args []driver.NamedValue
				if tc.hasParameters {
					args = []driver.NamedValue{
						{Name: "param1", Value: "value1"},
					}
				}

				_, err := testConn.executeStatement(context.Background(), "SELECT 1", args)
				assert.NoError(t, err)

				// assert.NoError does not stop the test; without this guard a
				// request that never reached the client would cause a nil
				// dereference below and panic instead of failing cleanly.
				if capturedReq == nil {
					t.Fatal("executeStatement did not send a request to the client")
				}

				// Verify direct results
				hasDirectResults := tc.supportsDirectResults(version)
				assert.Equal(t, hasDirectResults, capturedReq.GetDirectResults != nil, "Direct results should be enabled if protocol supports it")

				// Verify LZ4 compression
				shouldHaveLz4 := tc.supportsLz4Compression(version) && tc.cfg.UseLz4Compression
				if shouldHaveLz4 {
					// Only dereference once the pointer is known to be set.
					if assert.NotNil(t, capturedReq.CanDecompressLZ4Result_) {
						assert.True(t, *capturedReq.CanDecompressLZ4Result_)
					}
				} else {
					assert.Nil(t, capturedReq.CanDecompressLZ4Result_)
				}

				// Verify cloud fetch
				shouldHaveCloudFetch := tc.supportsCloudFetch(version) && tc.cfg.UseCloudFetch
				if shouldHaveCloudFetch {
					if assert.NotNil(t, capturedReq.CanDownloadResult_) {
						assert.True(t, *capturedReq.CanDownloadResult_)
					}
				} else {
					assert.Nil(t, capturedReq.CanDownloadResult_)
				}

				// Verify Arrow support
				shouldHaveArrow := tc.supportsArrow(version) && tc.cfg.UseArrowBatches
				if shouldHaveArrow {
					if assert.NotNil(t, capturedReq.CanReadArrowResult_) {
						assert.True(t, *capturedReq.CanReadArrowResult_)
					}
					if assert.NotNil(t, capturedReq.UseArrowNativeTypes) {
						assert.Equal(t, tc.cfg.UseArrowNativeDecimal, *capturedReq.UseArrowNativeTypes.DecimalAsArrow)
						assert.Equal(t, tc.cfg.UseArrowNativeTimestamp, *capturedReq.UseArrowNativeTypes.TimestampAsArrow)
						assert.Equal(t, tc.cfg.UseArrowNativeComplexTypes, *capturedReq.UseArrowNativeTypes.ComplexTypesAsArrow)
						assert.Equal(t, tc.cfg.UseArrowNativeIntervalTypes, *capturedReq.UseArrowNativeTypes.IntervalTypesAsArrow)
					}
				} else {
					assert.Nil(t, capturedReq.CanReadArrowResult_)
					assert.Nil(t, capturedReq.UseArrowNativeTypes)
				}

				// Verify parameters
				shouldHaveParams := tc.supportsParameterizedQueries(version) && tc.hasParameters
				switch {
				case shouldHaveParams:
					assert.NotNil(t, capturedReq.Parameters)
					assert.Len(t, capturedReq.Parameters, 1)
				case tc.hasParameters:
					// Even if we have parameters but protocol doesn't support it, we shouldn't set them
					assert.Nil(t, capturedReq.Parameters)
				default:
					// No parameters were supplied, so none may appear on the request.
					assert.Empty(t, capturedReq.Parameters)
				}
			})
		}
	}
}

func TestConn_pollOperation(t *testing.T) {
t.Parallel()
t.Run("pollOperation returns finished state response when query finishes", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
}
log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "")

log.Info().Msgf("connect: host=%s port=%d httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath)
log.Info().Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion)

for k, v := range c.cfg.SessionParams {
setStmt := fmt.Sprintf("SET `%s` = `%s`;", k, v)
Expand Down
5 changes: 5 additions & 0 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_servic
return resp, err
}

// Log the server protocol version
if resp != nil {
log.Debug().Msgf("Server protocol version: 0x%X", resp.ServerProtocolVersion)
}

recordResult(ctx, resp)

return resp, CheckStatus(resp)
Expand Down
2 changes: 1 addition & 1 deletion internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ func TestConfig_DeepCopy(t *testing.T) {
DriverVersion: "0.9.0",
ThriftProtocol: "binary",
ThriftTransport: "http",
ThriftProtocolVersion: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6,
ThriftProtocolVersion: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8,
ThriftDebugClientProtocol: false,
}

Expand Down
46 changes: 46 additions & 0 deletions internal/thrift_protocol/protocol_feature_util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package thrift_protocol

import "github.com/databricks/databricks-sql-go/internal/cli_service"

// Protocol feature checks.

// SupportsDirectResults reports whether the given server protocol version
// supports direct results. It returns true when version is
// SPARK_CLI_SERVICE_PROTOCOL_V1 or higher.
func SupportsDirectResults(version cli_service.TProtocolVersion) bool {
	return cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V1 <= version
}

// SupportsLz4Compression reports whether the given server protocol version
// supports LZ4 compression. It returns true when version is
// SPARK_CLI_SERVICE_PROTOCOL_V6 or higher.
func SupportsLz4Compression(version cli_service.TProtocolVersion) bool {
	return cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6 <= version
}

// SupportsCloudFetch reports whether the given server protocol version
// supports cloud fetch. It returns true when version is
// SPARK_CLI_SERVICE_PROTOCOL_V3 or higher.
func SupportsCloudFetch(version cli_service.TProtocolVersion) bool {
	return cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V3 <= version
}

// SupportsArrow reports whether the given server protocol version supports
// the Arrow result format. It returns true when version is
// SPARK_CLI_SERVICE_PROTOCOL_V5 or higher.
func SupportsArrow(version cli_service.TProtocolVersion) bool {
	return cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V5 <= version
}

// SupportsCompressedArrow reports whether the given server protocol version
// supports the compressed Arrow result format. It returns true when version
// is SPARK_CLI_SERVICE_PROTOCOL_V6 or higher.
func SupportsCompressedArrow(version cli_service.TProtocolVersion) bool {
	return cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6 <= version
}

// SupportsParameterizedQueries reports whether the given server protocol
// version supports parameterized queries. It returns true when version is
// SPARK_CLI_SERVICE_PROTOCOL_V8 or higher.
func SupportsParameterizedQueries(version cli_service.TProtocolVersion) bool {
	return cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8 <= version
}

// SupportsMultipleCatalogs reports whether the given server protocol version
// supports multiple catalogs. It returns true when version is
// SPARK_CLI_SERVICE_PROTOCOL_V4 or higher.
func SupportsMultipleCatalogs(version cli_service.TProtocolVersion) bool {
	return cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V4 <= version
}
Loading
Loading