Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions schema/arrow.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
package schema

import (
"crypto/sha1"
"time"

"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/cloudquery/plugin-sdk/v4/types"
"github.com/google/uuid"
)

const (
Expand Down Expand Up @@ -40,3 +47,98 @@ func (s Schemas) SchemaByName(name string) *arrow.Schema {
}
return nil
}

// hashRecord derives one UUID per row of record.
//
// For every row, each column contributes its field name followed by the
// column's string rendering of that row's value (ValueStr) to a SHA-1 digest.
// The digest is then passed through uuid.NewSHA1 under the URL namespace, and
// the resulting UUIDs are returned as an array with one entry per row.
func hashRecord(record arrow.Record) arrow.Array {
	rowCount := int(record.NumRows())
	colCount := int(record.NumCols())
	schemaFields := record.Schema().Fields()

	builder := types.NewUUIDBuilder(memory.DefaultAllocator)
	builder.Reserve(rowCount)

	for rowIdx := 0; rowIdx < rowCount; rowIdx++ {
		digest := sha1.New()
		for colIdx := 0; colIdx < colCount; colIdx++ {
			// hash.Hash.Write is documented to never return an error,
			// so both results are deliberately discarded.
			_, _ = digest.Write([]byte(schemaFields[colIdx].Name))
			cell := record.Column(colIdx).ValueStr(rowIdx)
			_, _ = digest.Write([]byte(cell))
		}
		// Wrapping the raw digest with uuid.NewSHA1 ensures that we conform
		// to the UUID spec.
		rowID := uuid.NewSHA1(uuid.NameSpaceURL, digest.Sum(nil))
		builder.Append(rowID)
	}
	return builder.NewArray()
}

// nullUUIDsForRecord builds a UUID array holding numRows null entries.
func nullUUIDsForRecord(numRows int) arrow.Array {
	nulls := types.NewUUIDBuilder(memory.DefaultAllocator)
	nulls.AppendNulls(numRows)
	return nulls.NewArray()
}

// StringArrayFromValue returns a string array in which value is appended
// nRows times.
func StringArrayFromValue(value string, nRows int) arrow.Array {
	sb := array.NewStringBuilder(memory.DefaultAllocator)
	sb.Reserve(nRows)
	for remaining := nRows; remaining > 0; remaining-- {
		sb.AppendString(value)
	}
	return sb.NewArray()
}

// TimestampArrayFromTime returns a timestamp array in which t, converted to
// the given unit, is appended nRows times. The array's type carries unit and
// timeZone. An error is returned when arrow.TimestampFromTime cannot convert
// t for the requested unit.
func TimestampArrayFromTime(t time.Time, unit arrow.TimeUnit, timeZone string, nRows int) (arrow.Array, error) {
	stamp, convErr := arrow.TimestampFromTime(t, unit)
	if convErr != nil {
		return nil, convErr
	}

	tsType := &arrow.TimestampType{Unit: unit, TimeZone: timeZone}
	tb := array.NewTimestampBuilder(memory.DefaultAllocator, tsType)
	tb.Reserve(nRows)
	for i := 0; i < nRows; i++ {
		tb.Append(stamp)
	}
	return tb.NewArray(), nil
}

// ReplaceFieldInRecord returns a record in which every column of src whose
// field is named fieldName has been replaced by field.
//
// Replacements are chained, so when several fields share the name all of them
// are replaced in the returned record. When no field matches, src itself is
// returned with an additional reference (Retain), so the caller always
// receives a non-nil record that it may Release independently of src.
//
// An error from SetColumn is returned as-is together with a nil record.
func ReplaceFieldInRecord(src arrow.Record, fieldName string, field arrow.Array) (record arrow.Record, err error) {
	fieldIndexes := src.Schema().FieldIndices(fieldName)

	// Take our own reference on src so that the running record can always be
	// released uniformly below, and so that the zero-match case hands the
	// caller a reference it owns.
	src.Retain()
	record = src

	for _, idx := range fieldIndexes {
		var next arrow.Record
		next, err = record.SetColumn(idx, field)
		// The previous record is no longer needed whether or not SetColumn
		// succeeded: on success next holds its own references to the columns.
		record.Release()
		if err != nil {
			return nil, err
		}
		record = next
	}
	return record, nil
}

// AddInternalColumnsToRecord returns a new record made of record's columns
// plus the internal columns that its schema is missing.
//
//   - CqIDColumn: added when absent, filled with the per-row hash from hashRecord.
//   - CqParentIDColumn: added when absent, filled with nulls.
//   - CqClientIDColumn: added when absent; when already present its values are
//     replaced. Either way every row ends up holding cqClientIDValue.
//
// New columns are appended after the existing ones, and the original schema
// metadata is carried over. An error is returned only when replacing the
// existing client ID column fails.
func AddInternalColumnsToRecord(record arrow.Record, cqClientIDValue string) (arrow.Record, error) {
	original := record.Schema()
	rowCount := int(record.NumRows())

	var (
		extraFields  []arrow.Field
		extraColumns []arrow.Array
	)

	if !original.HasField(CqIDColumn.Name) {
		ids := hashRecord(record)
		extraFields = append(extraFields, CqIDColumn.ToArrowField())
		extraColumns = append(extraColumns, ids)
	}

	if !original.HasField(CqParentIDColumn.Name) {
		parentIDs := nullUUIDsForRecord(rowCount)
		extraFields = append(extraFields, CqParentIDColumn.ToArrowField())
		extraColumns = append(extraColumns, parentIDs)
	}

	clientIDs := StringArrayFromValue(cqClientIDValue, rowCount)
	if original.HasField(CqClientIDColumn.Name) {
		// The column already exists: overwrite its values in place rather
		// than appending a duplicate field.
		replaced, err := ReplaceFieldInRecord(record, CqClientIDColumn.Name, clientIDs)
		if err != nil {
			return nil, err
		}
		record = replaced
	} else {
		extraFields = append(extraFields, CqClientIDColumn.ToArrowField())
		extraColumns = append(extraColumns, clientIDs)
	}

	mergedFields := append(original.Fields(), extraFields...)
	mergedColumns := append(record.Columns(), extraColumns...)
	meta := original.Metadata()
	mergedSchema := arrow.NewSchema(mergedFields, &meta)
	return array.NewRecord(mergedSchema, mergedColumns, int64(rowCount)), nil
}
131 changes: 131 additions & 0 deletions schema/arrow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@ package schema
import (
"fmt"
"strings"
"testing"

"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/cloudquery/plugin-sdk/v4/types"
"github.com/google/uuid"
"github.com/samber/lo"
"github.com/stretchr/testify/require"
)

func RecordDiff(l arrow.Record, r arrow.Record) string {
Expand All @@ -31,3 +37,128 @@ func RecordDiff(l arrow.Record, r arrow.Record) string {
}
return sb.String()
}

// buildTestRecord creates a 10-row record with the columns id (int64), name
// (string), value (float64), bool (boolean) and uuid (UUID). Row i holds i,
// "test<i>", float64(i), i%2 == 0 and a SHA-1 UUID of "test<i>" respectively.
//
// When withClientIDValue is non-empty, a CqClientIDColumn field is appended
// with every row set to withClientIDValue.
func buildTestRecord(withClientIDValue string) arrow.Record {
	const rowCount = 10
	includeClientID := withClientIDValue != ""

	fields := []arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
		{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true},
		{Name: "value", Type: arrow.PrimitiveTypes.Float64, Nullable: true},
		{Name: "bool", Type: arrow.FixedWidthTypes.Boolean, Nullable: true},
		{Name: "uuid", Type: types.UUID, Nullable: true},
	}
	if includeClientID {
		fields = append(fields, CqClientIDColumn.ToArrowField())
	}

	idCol := array.NewInt64Builder(memory.DefaultAllocator)
	nameCol := array.NewStringBuilder(memory.DefaultAllocator)
	valueCol := array.NewFloat64Builder(memory.DefaultAllocator)
	boolCol := array.NewBooleanBuilder(memory.DefaultAllocator)
	uuidCol := types.NewUUIDBuilder(memory.DefaultAllocator)

	// Order must match the order of fields above.
	columnBuilders := []array.Builder{idCol, nameCol, valueCol, boolCol, uuidCol}
	for _, cb := range columnBuilders {
		cb.Reserve(rowCount)
	}

	for i := range rowCount {
		label := fmt.Sprintf("test%d", i)
		idCol.Append(int64(i))
		nameCol.AppendString(label)
		valueCol.Append(float64(i))
		boolCol.Append(i%2 == 0)
		uuidCol.Append(uuid.NewSHA1(uuid.NameSpaceURL, []byte(label)))
	}

	if includeClientID {
		clientIDCol := array.NewStringBuilder(memory.DefaultAllocator)
		clientIDCol.Reserve(rowCount)
		for range rowCount {
			clientIDCol.AppendString(withClientIDValue)
		}
		columnBuilders = append(columnBuilders, clientIDCol)
	}

	columns := lo.Map(columnBuilders, func(cb array.Builder, _ int) arrow.Array {
		return cb.NewArray()
	})
	return array.NewRecord(arrow.NewSchema(fields, nil), columns, int64(rowCount))
}

// TestAddInternalColumnsToRecord checks that AddInternalColumnsToRecord yields
// exactly one _cq_id, _cq_parent_id and _cq_client_id column with the expected
// types and lengths, that _cq_id values are set and distinct per row, that
// _cq_parent_id is null in every row, and that _cq_client_id carries the
// requested value in every row — both when the input record has no client ID
// column and when it already has one that must be overwritten.
func TestAddInternalColumnsToRecord(t *testing.T) {
	tests := []struct {
		name               string
		record             arrow.Record
		cqClientIDValue    string
		expectedNewColumns int64
	}{
		{
			name:               "add _cq_id,_cq_parent_id,_cq_client_id",
			record:             buildTestRecord(""),
			cqClientIDValue:    "new_client_id",
			expectedNewColumns: 3,
		},
		{
			// The input already has a client ID column, so only two columns
			// are added and the existing one is overwritten.
			name:               "add _cq_id,_cq_parent_id replace existing _cq_client_id",
			record:             buildTestRecord("existing_client_id"),
			cqClientIDValue:    "new_client_id",
			expectedNewColumns: 2,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := AddInternalColumnsToRecord(tt.record, tt.cqClientIDValue)
			require.NoError(t, err)
			require.Equal(t, tt.record.NumRows(), got.NumRows())
			require.Equal(t, tt.record.NumCols()+tt.expectedNewColumns, got.NumCols())

			gotSchema := got.Schema()
			cqIDFields := gotSchema.FieldIndices(CqIDColumn.Name)
			require.Len(t, cqIDFields, 1)

			cqParentIDFields := gotSchema.FieldIndices(CqParentIDColumn.Name)
			require.Len(t, cqParentIDFields, 1)

			cqClientIDFields := gotSchema.FieldIndices(CqClientIDColumn.Name)
			require.Len(t, cqClientIDFields, 1)

			cqIDArray := got.Column(cqIDFields[0])
			require.Equal(t, types.UUID, cqIDArray.DataType())
			require.Equal(t, tt.record.NumRows(), int64(cqIDArray.Len()))

			cqParentIDArray := got.Column(cqParentIDFields[0])
			require.Equal(t, types.UUID, cqParentIDArray.DataType())
			require.Equal(t, tt.record.NumRows(), int64(cqParentIDArray.Len()))

			cqClientIDArray := got.Column(cqClientIDFields[0])
			require.Equal(t, arrow.BinaryTypes.String, cqClientIDArray.DataType())
			require.Equal(t, tt.record.NumRows(), int64(cqClientIDArray.Len()))

			// Every test row has distinct content, so each row must get its
			// own non-zero ID; a constant hash would otherwise go unnoticed.
			seenIDs := make(map[uuid.UUID]struct{}, cqIDArray.Len())
			for i := range cqIDArray.Len() {
				cqID, ok := cqIDArray.GetOneForMarshal(i).(uuid.UUID)
				require.True(t, ok, "row %d: _cq_id is not a uuid.UUID", i)
				require.NotEmpty(t, cqID)
				_, duplicate := seenIDs[cqID]
				require.False(t, duplicate, "row %d: duplicate _cq_id %s", i, cqID)
				seenIDs[cqID] = struct{}{}
			}
			for i := range cqParentIDArray.Len() {
				cqParentID := cqParentIDArray.GetOneForMarshal(i)
				require.Nil(t, cqParentID)
			}
			for i := range cqClientIDArray.Len() {
				cqClientID, ok := cqClientIDArray.GetOneForMarshal(i).(string)
				require.True(t, ok, "row %d: _cq_client_id is not a string", i)
				require.Equal(t, tt.cqClientIDValue, cqClientID)
			}
		})
	}
}
Loading