Commit 6f295b6

#53 Adding support for UserAgent and UserId
### What changes were proposed in this pull request?

This patch adds proper support for the `user_agent` and `user_id` properties of the connection string. In addition, it provides sensible default values for these parameters when connecting from a Go application.

### Why are the changes needed?

Compatibility

### Does this PR introduce _any_ user-facing change?

Compatibility

### How was this patch tested?

Added tests

Closes #105 from grundprinzip/53.

Authored-by: Martin Grund <martin.grund@databricks.com>
Signed-off-by: Martin Grund <martin.grund@databricks.com>
1 parent a24841f commit 6f295b6
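For context, a minimal sketch of the new connection-string handling from a Go client. The host, port, and property values below are placeholders, and the platform suffix of the printed user agent varies by machine:

package main

import (
    "fmt"
    "log"

    "github.com/apache/spark-connect-go/v35/spark/client/channel"
)

func main() {
    // user_id and user_agent are now recognized connection-string properties
    // instead of being passed through as generic gRPC headers.
    cb, err := channel.NewBuilder("sc://localhost:15002/;user_id=alice;user_agent=my-app")
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(cb.User())      // "alice"
    fmt.Println(cb.UserAgent()) // e.g. "my-app spark/3.5.x os/linux go/go1.22.0"
}

If user_id is omitted, the builder now falls back to the USER environment variable and finally to "na".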

File tree

6 files changed (+139, -87 lines)

spark/client/channel/channel.go

Lines changed: 36 additions & 0 deletions
@@ -24,9 +24,13 @@ import (
     "fmt"
     "net"
     "net/url"
+    "os"
+    "runtime"
     "strconv"
     "strings"

+    "github.com/apache/spark-connect-go/v35/spark"
+
     "github.com/google/uuid"

     "google.golang.org/grpc/credentials/insecure"
@@ -56,6 +60,9 @@ type Builder interface {
     // SessionId identifies the client side session identifier. This value must be a UUID formatted
     // as a string.
     SessionId() string
+    // UserAgent identifies the user agent string that is passed as part of the request. It contains
+    // information about the operating system, Go version etc.
+    UserAgent() string
 }

 // BaseBuilder is used to parse the different parameters of the connection
@@ -69,6 +76,7 @@ type BaseBuilder struct {
     user      string
     headers   map[string]string
     sessionId string
+    userAgent string
 }

 func (cb *BaseBuilder) Host() string {
@@ -95,6 +103,10 @@ func (cb *BaseBuilder) SessionId() string {
     return cb.sessionId
 }

+func (cb *BaseBuilder) UserAgent() string {
+    return cb.userAgent
+}
+
 // Build finalizes the creation of the gprc.ClientConn by creating a GRPC channel
 // with the necessary options extracted from the connection string. For
 // TLS connections, this function will load the system certificates.
@@ -178,6 +190,7 @@ func NewBuilder(connection string) (*BaseBuilder, error) {
         port:      port,
         headers:   map[string]string{},
         sessionId: uuid.NewString(),
+        userAgent: "",
     }

     elements := strings.Split(u.Path, ";")
@@ -190,10 +203,33 @@ func NewBuilder(connection string) (*BaseBuilder, error) {
                 cb.user = props[1]
             } else if props[0] == "session_id" {
                 cb.sessionId = props[1]
+            } else if props[0] == "user_agent" {
+                cb.userAgent = props[1]
             } else {
                 cb.headers[props[0]] = props[1]
             }
         }
     }
+
+    // Set default user ID if not set.
+    if cb.user == "" {
+        cb.user = os.Getenv("USER")
+        if cb.user == "" {
+            cb.user = "na"
+        }
+    }
+
+    // Update the user agent if it is not set or set to a custom value.
+    val := os.Getenv("SPARK_CONNECT_USER_AGENT")
+    if cb.userAgent == "" && val != "" {
+        cb.userAgent = os.Getenv("SPARK_CONNECT_USER_AGENT")
+    } else if cb.userAgent == "" {
+        cb.userAgent = "_SPARK_CONNECT_GO"
+    }
+
+    // In addition to the specified user agent, we need to append information about the
+    // host encoded as user agent components.
+    cb.userAgent = fmt.Sprintf("%s spark/%s os/%s go/%s", cb.userAgent, spark.Version(), runtime.GOOS, runtime.Version())
+
     return cb, nil
 }
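The resolution order implemented above, as a runnable sketch. Output values are illustrative; the spark/os/go components come from spark.Version(), runtime.GOOS, and runtime.Version():

package main

import (
    "fmt"
    "log"
    "os"

    "github.com/apache/spark-connect-go/v35/spark/client/channel"
)

func main() {
    // No property and no environment variable: the "_SPARK_CONNECT_GO" default.
    cb, err := channel.NewBuilder("sc://localhost")
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(cb.UserAgent()) // "_SPARK_CONNECT_GO spark/3.5.x os/... go/..."

    // SPARK_CONNECT_USER_AGENT is used when the property is absent.
    os.Setenv("SPARK_CONNECT_USER_AGENT", "env-agent")
    cb, _ = channel.NewBuilder("sc://localhost")
    fmt.Println(cb.UserAgent()) // "env-agent spark/3.5.x os/... go/..."

    // An explicit user_agent property wins over the environment variable.
    cb, _ = channel.NewBuilder("sc://localhost/;user_agent=my-app")
    fmt.Println(cb.UserAgent()) // "my-app spark/3.5.x os/... go/..."
}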

spark/client/channel/channel_test.go

Lines changed: 16 additions & 0 deletions
@@ -96,3 +96,19 @@ func TestChannelBuildConnect(t *testing.T) {
     assert.Nil(t, err, "no error for proper connection")
     assert.NotNil(t, conn)
 }
+
+func TestChannelBulder_UserAgent(t *testing.T) {
+    cb, err := channel.NewBuilder("sc://localhost")
+    assert.NoError(t, err)
+    assert.True(t, strings.Contains(cb.UserAgent(), "_SPARK_CONNECT_GO"))
+    assert.True(t, strings.Contains(cb.UserAgent(), "go/"))
+    assert.True(t, strings.Contains(cb.UserAgent(), "spark/"))
+    assert.True(t, strings.Contains(cb.UserAgent(), "os/"))
+
+    cb, err = channel.NewBuilder("sc://localhost/;user_agent=custom")
+    assert.NoError(t, err)
+    assert.True(t, strings.Contains(cb.UserAgent(), "custom"))
+    assert.True(t, strings.Contains(cb.UserAgent(), "go/"))
+    assert.True(t, strings.Contains(cb.UserAgent(), "spark/"))
+    assert.True(t, strings.Contains(cb.UserAgent(), "os/"))
+}

spark/client/client.go

Lines changed: 52 additions & 86 deletions
@@ -56,8 +56,9 @@ func (s *sparkConnectClientImpl) newExecutePlanRequest(plan *proto.Plan) *proto.
         SessionId: s.sessionId,
         Plan:      plan,
         UserContext: &proto.UserContext{
-            UserId: "na",
+            UserId: s.opts.UserId,
         },
+        ClientType: &s.opts.UserAgent,
         // Operation ID is needed for being able to reattach.
         OperationId: &operationId,
         RequestOptions: []*proto.ExecutePlanRequest_RequestOption{
@@ -109,16 +110,22 @@ func (s *sparkConnectClientImpl) ExecutePlan(ctx context.Context, plan *proto.Pl
     return NewExecuteResponseStream(c, s.sessionId, *request.OperationId, s.opts), nil
 }

-func (s *sparkConnectClientImpl) AnalyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) {
-    request := proto.AnalyzePlanRequest{
+// Creates a new AnalyzePlanRequest with the necessary metadata.
+func (s *sparkConnectClientImpl) newAnalyzePlanStub() proto.AnalyzePlanRequest {
+    return proto.AnalyzePlanRequest{
         SessionId: s.sessionId,
-        Analyze: &proto.AnalyzePlanRequest_Schema_{
-            Schema: &proto.AnalyzePlanRequest_Schema{
-                Plan: plan,
-            },
-        },
         UserContext: &proto.UserContext{
-            UserId: "na",
+            UserId: s.opts.UserId,
+        },
+        ClientType: &s.opts.UserAgent,
+    }
+}
+
+func (s *sparkConnectClientImpl) AnalyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) {
+    request := s.newAnalyzePlanStub()
+    request.Analyze = &proto.AnalyzePlanRequest_Schema_{
+        Schema: &proto.AnalyzePlanRequest_Schema{
+            Plan: plan,
         },
     }
     // Append the other items to the request.
@@ -149,17 +156,11 @@ func (s *sparkConnectClientImpl) Explain(ctx context.Context, plan *proto.Plan,
         return nil, sparkerrors.WithType(fmt.Errorf("unsupported explain mode %v",
             explainMode), sparkerrors.InvalidArgumentError)
     }
-
-    request := proto.AnalyzePlanRequest{
-        SessionId: s.sessionId,
-        Analyze: &proto.AnalyzePlanRequest_Explain_{
-            Explain: &proto.AnalyzePlanRequest_Explain{
-                Plan:        plan,
-                ExplainMode: mode,
-            },
-        },
-        UserContext: &proto.UserContext{
-            UserId: "na",
+    request := s.newAnalyzePlanStub()
+    request.Analyze = &proto.AnalyzePlanRequest_Explain_{
+        Explain: &proto.AnalyzePlanRequest_Explain{
+            Plan:        plan,
+            ExplainMode: mode,
         },
     }
     // Append the other items to the request.
@@ -174,17 +175,11 @@ func (s *sparkConnectClientImpl) Explain(ctx context.Context, plan *proto.Plan,

 func (s *sparkConnectClientImpl) Persist(ctx context.Context, plan *proto.Plan, storageLevel utils.StorageLevel) error {
     protoLevel := utils.ToProtoStorageLevel(storageLevel)
-
-    request := proto.AnalyzePlanRequest{
-        SessionId: s.sessionId,
-        Analyze: &proto.AnalyzePlanRequest_Persist_{
-            Persist: &proto.AnalyzePlanRequest_Persist{
-                Relation:     plan.GetRoot(),
-                StorageLevel: protoLevel,
-            },
-        },
-        UserContext: &proto.UserContext{
-            UserId: "na",
+    request := s.newAnalyzePlanStub()
+    request.Analyze = &proto.AnalyzePlanRequest_Persist_{
+        Persist: &proto.AnalyzePlanRequest_Persist{
+            Relation:     plan.GetRoot(),
+            StorageLevel: protoLevel,
         },
     }
     // Append the other items to the request.
@@ -198,15 +193,10 @@ func (s *sparkConnectClientImpl) Persist(ctx context.Context, plan *proto.Plan,
 }

 func (s *sparkConnectClientImpl) Unpersist(ctx context.Context, plan *proto.Plan) error {
-    request := proto.AnalyzePlanRequest{
-        SessionId: s.sessionId,
-        Analyze: &proto.AnalyzePlanRequest_Unpersist_{
-            Unpersist: &proto.AnalyzePlanRequest_Unpersist{
-                Relation: plan.GetRoot(),
-            },
-        },
-        UserContext: &proto.UserContext{
-            UserId: "na",
+    request := s.newAnalyzePlanStub()
+    request.Analyze = &proto.AnalyzePlanRequest_Unpersist_{
+        Unpersist: &proto.AnalyzePlanRequest_Unpersist{
+            Relation: plan.GetRoot(),
         },
     }
     // Append the other items to the request.
@@ -220,15 +210,10 @@ func (s *sparkConnectClientImpl) Unpersist(ctx context.Context, plan *proto.Plan
 }

 func (s *sparkConnectClientImpl) GetStorageLevel(ctx context.Context, plan *proto.Plan) (*utils.StorageLevel, error) {
-    request := proto.AnalyzePlanRequest{
-        SessionId: s.sessionId,
-        Analyze: &proto.AnalyzePlanRequest_GetStorageLevel_{
-            GetStorageLevel: &proto.AnalyzePlanRequest_GetStorageLevel{
-                Relation: plan.GetRoot(),
-            },
-        },
-        UserContext: &proto.UserContext{
-            UserId: "na",
+    request := s.newAnalyzePlanStub()
+    request.Analyze = &proto.AnalyzePlanRequest_GetStorageLevel_{
+        GetStorageLevel: &proto.AnalyzePlanRequest_GetStorageLevel{
+            Relation: plan.GetRoot(),
         },
     }
     // Append the other items to the request.
@@ -245,14 +230,9 @@ func (s *sparkConnectClientImpl) GetStorageLevel(ctx context.Context, plan *prot
 }

 func (s *sparkConnectClientImpl) SparkVersion(ctx context.Context) (string, error) {
-    request := proto.AnalyzePlanRequest{
-        SessionId: s.sessionId,
-        Analyze: &proto.AnalyzePlanRequest_SparkVersion_{
-            SparkVersion: &proto.AnalyzePlanRequest_SparkVersion{},
-        },
-        UserContext: &proto.UserContext{
-            UserId: "na",
-        },
+    request := s.newAnalyzePlanStub()
+    request.Analyze = &proto.AnalyzePlanRequest_SparkVersion_{
+        SparkVersion: &proto.AnalyzePlanRequest_SparkVersion{},
     }
     // Append the other items to the request.
     ctx = metadata.NewOutgoingContext(ctx, s.metadata)
@@ -265,15 +245,10 @@ func (s *sparkConnectClientImpl) SparkVersion(ctx context.Context) (string, erro
 }

 func (s *sparkConnectClientImpl) DDLParse(ctx context.Context, sql string) (*types.StructType, error) {
-    request := proto.AnalyzePlanRequest{
-        SessionId: s.sessionId,
-        Analyze: &proto.AnalyzePlanRequest_DdlParse{
-            DdlParse: &proto.AnalyzePlanRequest_DDLParse{
-                DdlString: sql,
-            },
-        },
-        UserContext: &proto.UserContext{
-            UserId: "na",
+    request := s.newAnalyzePlanStub()
+    request.Analyze = &proto.AnalyzePlanRequest_DdlParse{
+        DdlParse: &proto.AnalyzePlanRequest_DDLParse{
+            DdlString: sql,
         },
     }
     // Append the other items to the request.
@@ -287,16 +262,11 @@ func (s *sparkConnectClientImpl) DDLParse(ctx context.Context, sql string) (*typ
 }

 func (s *sparkConnectClientImpl) SameSemantics(ctx context.Context, plan1 *proto.Plan, plan2 *proto.Plan) (bool, error) {
-    request := proto.AnalyzePlanRequest{
-        SessionId: s.sessionId,
-        Analyze: &proto.AnalyzePlanRequest_SameSemantics_{
-            SameSemantics: &proto.AnalyzePlanRequest_SameSemantics{
-                TargetPlan: plan1,
-                OtherPlan:  plan2,
-            },
-        },
-        UserContext: &proto.UserContext{
-            UserId: "na",
+    request := s.newAnalyzePlanStub()
+    request.Analyze = &proto.AnalyzePlanRequest_SameSemantics_{
+        SameSemantics: &proto.AnalyzePlanRequest_SameSemantics{
+            TargetPlan: plan1,
+            OtherPlan:  plan2,
         },
     }
     // Append the other items to the request.
@@ -310,15 +280,10 @@ func (s *sparkConnectClientImpl) SameSemantics(ctx context.Context, plan1 *proto
 }

 func (s *sparkConnectClientImpl) SemanticHash(ctx context.Context, plan *proto.Plan) (int32, error) {
-    request := proto.AnalyzePlanRequest{
-        SessionId: s.sessionId,
-        Analyze: &proto.AnalyzePlanRequest_SemanticHash_{
-            SemanticHash: &proto.AnalyzePlanRequest_SemanticHash{
-                Plan: plan,
-            },
-        },
-        UserContext: &proto.UserContext{
-            UserId: "na",
+    request := s.newAnalyzePlanStub()
+    request.Analyze = &proto.AnalyzePlanRequest_SemanticHash_{
+        SemanticHash: &proto.AnalyzePlanRequest_SemanticHash{
+            Plan: plan,
         },
     }
     // Append the other items to the request.
@@ -337,8 +302,9 @@ func (s *sparkConnectClientImpl) Config(ctx context.Context,
     request := &proto.ConfigRequest{
         Operation: operation,
         UserContext: &proto.UserContext{
-            UserId: "na",
+            UserId: s.opts.UserId,
         },
+        ClientType: &s.opts.UserAgent,
     }
     request.SessionId = s.sessionId
     resp, err := s.client.Config(ctx, request)
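All nine analyze helpers now share newAnalyzePlanStub, so the request metadata can be verified in one place. A sketch of an in-package unit test, assuming sparkConnectClientImpl can be constructed directly with just the sessionId and opts fields (the stub reads nothing else):

package client

import (
    "testing"

    "github.com/apache/spark-connect-go/v35/spark/client/options"
    "github.com/stretchr/testify/assert"
)

func TestNewAnalyzePlanStub_SetsMetadata(t *testing.T) {
    s := &sparkConnectClientImpl{
        sessionId: "session-1",
        opts: options.SparkClientOptions{
            UserId:    "alice",
            UserAgent: "my-app",
        },
    }
    req := s.newAnalyzePlanStub()
    assert.Equal(t, "session-1", req.SessionId)
    assert.Equal(t, "alice", req.UserContext.UserId)
    assert.Equal(t, "my-app", *req.ClientType)
}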

spark/client/options/options.go

Lines changed: 8 additions & 0 deletions
@@ -17,8 +17,16 @@ package options

 type SparkClientOptions struct {
     ReattachExecution bool
+    UserAgent         string
+    UserId            string
 }

 var DefaultSparkClientOptions = SparkClientOptions{
     ReattachExecution: false,
 }
+
+func NewSparkClientOptions(reattach bool) SparkClientOptions {
+    return SparkClientOptions{
+        ReattachExecution: reattach,
+    }
+}
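Note that NewSparkClientOptions only carries over the reattach flag; UserAgent and UserId stay zero-valued until the caller fills them in, as the session builder does in spark/sql/sparksession.go below. A minimal sketch with illustrative values:

package main

import (
    "fmt"

    "github.com/apache/spark-connect-go/v35/spark/client/options"
)

func main() {
    // Keep the default reattach behavior, then set the identity fields by hand.
    opts := options.NewSparkClientOptions(options.DefaultSparkClientOptions.ReattachExecution)
    opts.UserAgent = "my-app spark/3.5.x os/linux go/go1.22.0" // illustrative value
    opts.UserId = "alice"
    fmt.Printf("%+v\n", opts)
}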

spark/sql/sparksession.go

Lines changed: 7 additions & 1 deletion
@@ -93,9 +93,15 @@ func (s *SparkSessionBuilder) Build(ctx context.Context) (SparkSession, error) {
     }

     sessionId := uuid.NewString()
+
+    // Update the options according to the configuration.
+    opts := options.NewSparkClientOptions(options.DefaultSparkClientOptions.ReattachExecution)
+    opts.UserAgent = s.channelBuilder.UserAgent()
+    opts.UserId = s.channelBuilder.User()
+
     return &sparkSessionImpl{
         sessionId: sessionId,
-        client:    client.NewSparkExecutor(conn, meta, sessionId, options.DefaultSparkClientOptions),
+        client:    client.NewSparkExecutor(conn, meta, sessionId, opts),
     }, nil
 }
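End to end, this means an operator can tag every request from a deployment without code changes by exporting SPARK_CONNECT_USER_AGENT before starting the application. A sketch; the sql.NewSessionBuilder and Remote entrypoints are assumptions here and may differ between versions of the repository:

package main

import (
    "context"
    "log"
    "os"

    "github.com/apache/spark-connect-go/v35/spark/sql"
)

func main() {
    // Hypothetical deployment-wide tag; picked up by channel.NewBuilder.
    os.Setenv("SPARK_CONNECT_USER_AGENT", "nightly-etl")

    // NewSessionBuilder/Remote are assumed entrypoints; check the repo README
    // for the exact builder API in your version.
    spark, err := sql.NewSessionBuilder().Remote("sc://localhost:15002").Build(context.Background())
    if err != nil {
        log.Fatal(err)
    }
    _ = spark // Every RPC now carries UserContext.UserId and ClientType.
}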

spark/version.go

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package spark
+
+func Version() string {
+    return "3.5.x"
+}
