|
9 | 9 | "database/sql" |
10 | 10 | "database/sql/driver" |
11 | 11 | "encoding/binary" |
| 12 | +"encoding/json" |
12 | 13 | "errors" |
13 | 14 | "fmt" |
14 | 15 | "io" |
@@ -1143,6 +1144,10 @@ func isDriverSetting(key string) bool { |
1143 | 1144 | return true |
1144 | 1145 | case "password": |
1145 | 1146 | return true |
| 1147 | +case "oauth_token": |
| 1148 | +return true |
| 1149 | +case "oauth_token_file": |
| 1150 | +return true |
1146 | 1151 | case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni": |
1147 | 1152 | return true |
1148 | 1153 | case "fallback_application_name": |
@@ -1290,59 +1295,135 @@ func (cn *conn) auth(r *readBuf, o values) { |
1290 | 1295 | // from the server.. |
1291 | 1296 |
|
1292 | 1297 | case 10: |
1293 | | -sc := scram.NewClient(sha256.New, o["user"], o["password"]) |
1294 | | -sc.Step(nil) |
1295 | | -if sc.Err() != nil { |
1296 | | -errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
| 1298 | +switch saslMethod := r.string(); saslMethod { |
| 1299 | +case "SCRAM-SHA-256": |
| 1300 | +cn.saslScram(o) |
| 1301 | +case "OAUTHBEARER": |
| 1302 | +cn.saslOAuth(o) |
1297 | 1303 | } |
1298 | | -scOut := sc.Out() |
1299 | 1304 |
|
1300 | | -w := cn.writeBuf('p') |
1301 | | -w.string("SCRAM-SHA-256") |
1302 | | -w.int32(len(scOut)) |
1303 | | -w.bytes(scOut) |
1304 | | -cn.send(w) |
| 1305 | +default: |
| 1306 | +errorf("unknown authentication response: %d", code) |
| 1307 | +} |
| 1308 | +} |
1305 | 1309 |
|
1306 | | -t, r := cn.recv() |
1307 | | -if t != 'R' { |
1308 | | -errorf("unexpected password response: %q", t) |
1309 | | -} |
| 1310 | +func (cn *conn) saslScram(o values) { |
| 1311 | +sc := scram.NewClient(sha256.New, o["user"], o["password"]) |
| 1312 | +sc.Step(nil) |
| 1313 | +if sc.Err() != nil { |
| 1314 | +errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
| 1315 | +} |
| 1316 | +scOut := sc.Out() |
1310 | 1317 |
|
1311 | | -if r.int32() != 11 { |
1312 | | -errorf("unexpected authentication response: %q", t) |
1313 | | -} |
| 1318 | +w := cn.writeBuf('p') |
| 1319 | +w.string("SCRAM-SHA-256") |
| 1320 | +w.int32(len(scOut)) |
| 1321 | +w.bytes(scOut) |
| 1322 | +cn.send(w) |
1314 | 1323 |
|
1315 | | -nextStep := r.next(len(*r)) |
1316 | | -sc.Step(nextStep) |
1317 | | -if sc.Err() != nil { |
1318 | | -errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
1319 | | -} |
| 1324 | +t, r := cn.recv() |
| 1325 | +if t != 'R' { |
| 1326 | +errorf("unexpected password response: %q", t) |
| 1327 | +} |
1320 | 1328 |
|
1321 | | -scOut = sc.Out() |
1322 | | -w = cn.writeBuf('p') |
1323 | | -w.bytes(scOut) |
1324 | | -cn.send(w) |
| 1329 | +if r.int32() != 11 { |
| 1330 | +errorf("unexpected authentication response: %q", t) |
| 1331 | +} |
1325 | 1332 |
|
1326 | | -t, r = cn.recv() |
1327 | | -if t != 'R' { |
1328 | | -errorf("unexpected password response: %q", t) |
1329 | | -} |
| 1333 | +nextStep := r.next(len(*r)) |
| 1334 | +sc.Step(nextStep) |
| 1335 | +if sc.Err() != nil { |
| 1336 | +errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
| 1337 | +} |
1330 | 1338 |
|
1331 | | -if r.int32() != 12 { |
1332 | | -errorf("unexpected authentication response: %q", t) |
| 1339 | +scOut = sc.Out() |
| 1340 | +w = cn.writeBuf('p') |
| 1341 | +w.bytes(scOut) |
| 1342 | +cn.send(w) |
| 1343 | + |
| 1344 | +t, r = cn.recv() |
| 1345 | +if t != 'R' { |
| 1346 | +errorf("unexpected password response: %q", t) |
| 1347 | +} |
| 1348 | + |
| 1349 | +if r.int32() != 12 { |
| 1350 | +errorf("unexpected authentication response: %q", t) |
| 1351 | +} |
| 1352 | + |
| 1353 | +nextStep = r.next(len(*r)) |
| 1354 | +sc.Step(nextStep) |
| 1355 | +if sc.Err() != nil { |
| 1356 | +errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
| 1357 | +} |
| 1358 | +} |
| 1359 | + |
| 1360 | +func (cn *conn) saslOAuth(o values) { |
| 1361 | +// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1 |
| 1362 | +w := cn.writeBuf('p') |
| 1363 | +w.string("OAUTHBEARER") |
| 1364 | + |
| 1365 | +token, err := getOAuthToken(o) |
| 1366 | +if err != nil { |
| 1367 | +errorf("failed to obtain oauth token: %s", err) |
| 1368 | +} |
| 1369 | +initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01") |
| 1370 | +w.int32(len(initialResponse)) |
| 1371 | +w.bytes(initialResponse) |
| 1372 | +cn.send(w) |
| 1373 | + |
| 1374 | +t, r := cn.recv() |
| 1375 | +if t != 'R' { |
| 1376 | +errorf("unexpected oauth response: %q", t) |
| 1377 | +} |
| 1378 | + |
| 1379 | +if code := r.int32(); code != 0 { |
| 1380 | +// usually on an authentication error we should get a |
| 1381 | +// AuthenticationSASLContinue message |
| 1382 | +if code != 11 { |
| 1383 | +errorf("unexpected oauth response: %q %d", t, code) |
1333 | 1384 | } |
1334 | 1385 |
|
1335 | | -nextStep = r.next(len(*r)) |
1336 | | -sc.Step(nextStep) |
1337 | | -if sc.Err() != nil { |
1338 | | -errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) |
| 1386 | +// the AuthenticationSASLContinue does have an error payload |
| 1387 | +// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2 |
| 1388 | +errResponse := struct { |
| 1389 | +Status string `json:"status"` |
| 1390 | +Scope string `json:"scope"` |
| 1391 | +OpenIDConfiguration string `json:"openid-configuration"` |
| 1392 | +}{} |
| 1393 | +err := json.Unmarshal(*r, &errResponse) |
| 1394 | +if err != nil { |
| 1395 | +errorf("invalid oauth error response") |
1339 | 1396 | } |
1340 | 1397 |
|
1341 | | -default: |
1342 | | -errorf("unknown authentication response: %d", code) |
| 1398 | +errorf("oauth authentication failed '%s'", errResponse.Status) |
| 1399 | + |
| 1400 | +// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3 |
| 1401 | +// we deliberately don't complete the error messaging sequence as described |
| 1402 | +// in 3.2.3 as we're going to close the connection either way |
| 1403 | +// w = cn.writeBuf('p') |
| 1404 | +// w.int32(1) |
| 1405 | +// w.bytes([]byte{0x01}) |
| 1406 | +// cn.send(w) |
1343 | 1407 | } |
1344 | 1408 | } |
1345 | 1409 |
|
| 1410 | +func getOAuthToken(o values) (string, error) { |
| 1411 | +if token, ok := o["oauth_token"]; ok { |
| 1412 | +return token, nil |
| 1413 | +} |
| 1414 | + |
| 1415 | +if tokenFile, ok := o["oauth_token_file"]; ok { |
| 1416 | +rawToken, err := os.ReadFile(tokenFile) |
| 1417 | +if err != nil { |
| 1418 | +return "", err |
| 1419 | +} |
| 1420 | +rawToken = bytes.TrimSuffix(rawToken, []byte("\n")) |
| 1421 | +return string(rawToken), nil |
| 1422 | +} |
| 1423 | + |
| 1424 | +return "", fmt.Errorf("no oauth token configured") |
| 1425 | +} |
| 1426 | + |
1346 | 1427 | type format int |
1347 | 1428 |
|
1348 | 1429 | const formatText format = 0 |
|
0 commit comments