Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (c *conn) Close() error {

if err != nil {
log.Err(err).Msg("databricks: failed to close connection")
return dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrCloseConnection, err)
return dbsqlerrint.NewBadConnectionError(err)
}
return nil
}
Expand Down Expand Up @@ -168,9 +168,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp)
}

corrId := driverctx.CorrelationIdFromContext(ctx)
rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)

rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
return rows, err

}
Expand Down Expand Up @@ -367,7 +365,7 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
var statusResp *cli_service.TGetOperationStatusResp
ctx = driverctx.NewContextWithConnId(ctx, c.id)
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
newCtx := driverctx.NewContextWithCorrelationId(ctx, corrId)
pollSentinel := sentinel.Sentinel{
OnDoneFn: func(statusResp any) (any, error) {
return statusResp, nil
Expand Down Expand Up @@ -566,7 +564,6 @@ func (c *conn) execStagingOperation(
return nil
}

corrId := driverctx.CorrelationIdFromContext(ctx)
var row driver.Rows
var err error

Expand All @@ -589,7 +586,7 @@ func (c *conn) execStagingOperation(
}

if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
if err != nil {
return dbsqlerrint.NewDriverError(ctx, "error reading row.", err)
}
Expand Down
16 changes: 7 additions & 9 deletions internal/rows/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,22 @@ var _ driver.RowsColumnTypeLength = (*rows)(nil)
var _ dbsqlrows.Rows = (*rows)(nil)

func NewRows(
connId string,
correlationId string,
ctx context.Context,
opHandle *cli_service.TOperationHandle,
client cli_service.TCLIService,
config *config.Config,
directResults *cli_service.TSparkDirectResults,
) (driver.Rows, dbsqlerr.DBError) {

connId := driverctx.ConnIdFromContext(ctx)
correlationId := driverctx.CorrelationIdFromContext(ctx)

var logger *dbsqllog.DBSQLLogger
var ctx context.Context
if opHandle != nil {
logger = dbsqllog.WithContext(connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
ctx = driverctx.NewContextWithQueryId(driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId), dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
ctx = driverctx.NewContextWithQueryId(ctx, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
} else {
logger = dbsqllog.WithContext(connId, correlationId, "")
ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId)
}

if client == nil {
Expand Down Expand Up @@ -140,13 +140,12 @@ func NewRows(
// the operations.
closedOnServer := directResults != nil && directResults.CloseOperation != nil
r.ResultPageIterator = rowscanner.NewResultPageIterator(
ctx,
d,
pageSize,
opHandle,
closedOnServer,
client,
connId,
correlationId,
r.logger(),
)

Expand Down Expand Up @@ -417,9 +416,8 @@ func (r *rows) getResultSetSchema() (*cli_service.TTableSchema, dbsqlerr.DBError
req := cli_service.TGetResultSetMetadataReq{
OperationHandle: r.opHandle,
}
ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId)

resp, err2 := r.client.GetResultSetMetadata(ctx, &req)
resp, err2 := r.client.GetResultSetMetadata(r.ctx, &req)
if err2 != nil {
r.logger().Err(err2).Msg(err2.Error())
return nil, dbsqlerr_int.NewRequestError(r.ctx, errRowsMetadataFetchFailed, err)
Expand Down
25 changes: 13 additions & 12 deletions internal/rows/rowscanner/resultPageIterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,34 @@ func (d Direction) String() string {

// Create a new result page iterator.
func NewResultPageIterator(
ctx context.Context,
delimiter Delimiter,
maxPageSize int64,
opHandle *cli_service.TOperationHandle,
closedOnServer bool,
client cli_service.TCLIService,
connectionId string,
correlationId string,
logger *dbsqllog.DBSQLLogger,
) ResultPageIterator {

// delimiter and hasMoreRows are used to set up the point in the paginated
// result set that this iterator starts from.
return &resultPageIterator{
ctx: ctx,
Delimiter: delimiter,
isFinished: closedOnServer,
maxPageSize: maxPageSize,
opHandle: opHandle,
closedOnServer: closedOnServer,
client: client,
connectionId: connectionId,
correlationId: correlationId,
connectionId: driverctx.ConnIdFromContext(ctx),
correlationId: driverctx.CorrelationIdFromContext(ctx),
logger: logger,
}
}

type resultPageIterator struct {
ctx context.Context

// Gives the parameters of the current result page
Delimiter

Expand Down Expand Up @@ -167,15 +169,14 @@ func (rpf *resultPageIterator) getNextPage() (*cli_service.TFetchResultsResp, er
nextPageStartRow := rpf.Start() + rpf.Count()

rpf.logger.Debug().Msgf("databricks: fetching result page for row %d", nextPageStartRow)
ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), rpf.connectionId), rpf.correlationId)

// Keep fetching in the appropriate direction until we have the expected page.
var fetchResult *cli_service.TFetchResultsResp
var b bool
for b = rpf.Contains(nextPageStartRow); !b; b = rpf.Contains(nextPageStartRow) {

direction := rpf.Direction(nextPageStartRow)
err := rpf.checkDirectionValid(ctx, direction)
err := rpf.checkDirectionValid(direction)
if err != nil {
return nil, err
}
Expand All @@ -190,10 +191,10 @@ func (rpf *resultPageIterator) getNextPage() (*cli_service.TFetchResultsResp, er
IncludeResultSetMetadata: &includeResultSetMetadata,
}

fetchResult, err = rpf.client.FetchResults(ctx, &req)
fetchResult, err = rpf.client.FetchResults(rpf.ctx, &req)
if err != nil {
rpf.logger.Err(err).Msg("databricks: Rows instance failed to retrieve results")
return nil, dbsqlerrint.NewRequestError(ctx, errRowsResultFetchFailed, err)
return nil, dbsqlerrint.NewRequestError(rpf.ctx, errRowsResultFetchFailed, err)
}

rpf.Delimiter = NewDelimiter(fetchResult.Results.StartRowOffset, CountRows(fetchResult.Results))
Expand All @@ -218,7 +219,7 @@ func (rpf *resultPageIterator) Close() (err error) {
OperationHandle: rpf.opHandle,
}

_, err = rpf.client.CloseOperation(context.Background(), &req)
_, err = rpf.client.CloseOperation(rpf.ctx, &req)
return err
}
}
Expand Down Expand Up @@ -283,11 +284,11 @@ func CountRows(rowSet *cli_service.TRowSet) int64 {
}

// Check if trying to fetch in the specified direction creates an error condition.
func (rpf *resultPageIterator) checkDirectionValid(ctx context.Context, direction Direction) error {
func (rpf *resultPageIterator) checkDirectionValid(direction Direction) error {
if direction == DirBack {
// can't fetch rows previous to the start
if rpf.Start() == 0 {
return dbsqlerrint.NewDriverError(ctx, ErrRowsFetchPriorToStart, nil)
return dbsqlerrint.NewDriverError(rpf.ctx, ErrRowsFetchPriorToStart, nil)
}
} else if direction == DirForward {
// can't fetch past the end of the query results
Expand All @@ -296,7 +297,7 @@ func (rpf *resultPageIterator) checkDirectionValid(ctx context.Context, directio
}
} else {
rpf.logger.Error().Msgf(errRowsUnandledFetchDirection(direction.String()))
return dbsqlerrint.NewDriverError(ctx, errRowsUnandledFetchDirection(direction.String()), nil)
return dbsqlerrint.NewDriverError(rpf.ctx, errRowsUnandledFetchDirection(direction.String()), nil)
}
return nil
}
Expand Down