Skip to content

Commit 0f9a587

Browse files
committed
Add SSL support..
SSL is disabled by default to avoid POLA violations. It is possible to enable and control SSL behavior via url parameters: - `sslmode=<mode>` enable ssl (prefer/require/verify-ca/verify-full [recommended]) - `sslrootcert=<path.pem>` specifies trusted certificates (JDK cacert if missing) Client certificate authentication is not implemented, due to lack of time and interest, but it should be easy to add.
1 parent c3747b5 commit 0f9a587

File tree

21 files changed

+364
-48
lines changed

21 files changed

+364
-48
lines changed

db-async-common/src/main/scala/com/github/mauricio/async/db/Configuration.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ object Configuration {
3737
* @param port database port, defaults to 5432
3838
* @param password password, defaults to no password
3939
* @param database database name, defaults to no database
40+
* @param ssl ssl configuration
4041
* @param charset charset for the connection, defaults to UTF-8, make sure you know what you are doing if you
4142
* change this
4243
* @param maximumMessageSize the maximum size a message from the server could possibly have, this limits possible
@@ -55,6 +56,7 @@ case class Configuration(username: String,
5556
port: Int = 5432,
5657
password: Option[String] = None,
5758
database: Option[String] = None,
59+
ssl: SSLConfiguration = SSLConfiguration(),
5860
charset: Charset = Configuration.DefaultCharset,
5961
maximumMessageSize: Int = 16777216,
6062
allocator: ByteBufAllocator = PooledByteBufAllocator.DEFAULT,
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package com.github.mauricio.async.db
2+
3+
import java.io.File
4+
5+
import SSLConfiguration.Mode
6+
7+
/**
8+
*
9+
* Contains the SSL configuration necessary to connect to a database.
10+
*
11+
* @param mode whether and with what priority a SSL connection will be negotiated, default disabled
12+
* @param rootCert path to PEM encoded trusted root certificates, None to use internal JDK cacerts, defaults to None
13+
*
14+
*/
15+
case class SSLConfiguration(mode: Mode.Value = Mode.Disable, rootCert: Option[java.io.File] = None)
16+
17+
object SSLConfiguration {
18+
19+
object Mode extends Enumeration {
20+
val Disable = Value("disable") // only try a non-SSL connection
21+
val Prefer = Value("prefer") // first try an SSL connection; if that fails, try a non-SSL connection
22+
val Require = Value("require") // only try an SSL connection, but don't verify Certificate Authority
23+
val VerifyCA = Value("verify-ca") // only try an SSL connection, and verify that the server certificate is issued by a trusted certificate authority (CA)
24+
val VerifyFull = Value("verify-full") // only try an SSL connection, verify that the server certificate is issued by a trusted CA and that the server host name matches that in the certificate
25+
}
26+
27+
def apply(properties: Map[String, String]): SSLConfiguration = SSLConfiguration(
28+
mode = Mode.withName(properties.get("sslmode").getOrElse("disable")),
29+
rootCert = properties.get("sslrootcert").map(new File(_))
30+
)
31+
}

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/MessageDecoder.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package com.github.mauricio.async.db.postgresql.codec
1818

1919
import com.github.mauricio.async.db.postgresql.exceptions.{MessageTooLongException}
20-
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
20+
import com.github.mauricio.async.db.postgresql.messages.backend.{ServerMessage, SSLResponseMessage}
2121
import com.github.mauricio.async.db.postgresql.parsers.{AuthenticationStartupParser, MessageParsersRegistry}
2222
import com.github.mauricio.async.db.util.{BufferDumper, Log}
2323
import java.nio.charset.Charset
@@ -31,15 +31,21 @@ object MessageDecoder {
3131
val DefaultMaximumSize = 16777216
3232
}
3333

34-
class MessageDecoder(charset: Charset, maximumMessageSize : Int = MessageDecoder.DefaultMaximumSize) extends ByteToMessageDecoder {
34+
class MessageDecoder(sslEnabled: Boolean, charset: Charset, maximumMessageSize : Int = MessageDecoder.DefaultMaximumSize) extends ByteToMessageDecoder {
3535

3636
import MessageDecoder.log
3737

3838
private val parser = new MessageParsersRegistry(charset)
3939

40+
private var sslChecked = false
41+
4042
override def decode(ctx: ChannelHandlerContext, b: ByteBuf, out: java.util.List[Object]): Unit = {
4143

42-
if (b.readableBytes() >= 5) {
44+
if (sslEnabled & !sslChecked) {
45+
val code = b.readByte()
46+
sslChecked = true
47+
out.add(new SSLResponseMessage(code == 'S'))
48+
} else if (b.readableBytes() >= 5) {
4349

4450
b.markReaderIndex()
4551

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/MessageEncoder.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,13 @@ class MessageEncoder(charset: Charset, encoderRegistry: ColumnEncoderRegistry) e
4444
override def encode(ctx: ChannelHandlerContext, msg: AnyRef, out: java.util.List[Object]) = {
4545

4646
val buffer = msg match {
47+
case SSLRequestMessage => SSLMessageEncoder.encode()
48+
case message: StartupMessage => startupEncoder.encode(message)
4749
case message: ClientMessage => {
4850
val encoder = (message.kind: @switch) match {
4951
case ServerMessage.Close => CloseMessageEncoder
5052
case ServerMessage.Execute => this.executeEncoder
5153
case ServerMessage.Parse => this.openEncoder
52-
case ServerMessage.Startup => this.startupEncoder
5354
case ServerMessage.Query => this.queryEncoder
5455
case ServerMessage.PasswordMessage => this.credentialEncoder
5556
case _ => throw new EncoderNotAvailableException(message)

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/PostgreSQLConnectionHandler.scala

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.github.mauricio.async.db.postgresql.codec
1818

1919
import com.github.mauricio.async.db.Configuration
20+
import com.github.mauricio.async.db.SSLConfiguration.Mode
2021
import com.github.mauricio.async.db.column.{ColumnDecoderRegistry, ColumnEncoderRegistry}
2122
import com.github.mauricio.async.db.postgresql.exceptions._
2223
import com.github.mauricio.async.db.postgresql.messages.backend._
@@ -38,6 +39,12 @@ import com.github.mauricio.async.db.postgresql.messages.backend.RowDescriptionMe
3839
import com.github.mauricio.async.db.postgresql.messages.backend.ParameterStatusMessage
3940
import io.netty.channel.socket.nio.NioSocketChannel
4041
import io.netty.handler.codec.CodecException
42+
import io.netty.handler.ssl.{SslContextBuilder, SslHandler}
43+
import io.netty.handler.ssl.util.InsecureTrustManagerFactory
44+
import io.netty.util.concurrent.FutureListener
45+
import javax.net.ssl.{SSLParameters, TrustManagerFactory}
46+
import java.security.KeyStore
47+
import java.io.FileInputStream
4148

4249
object PostgreSQLConnectionHandler {
4350
final val log = Log.get[PostgreSQLConnectionHandler]
@@ -79,7 +86,7 @@ class PostgreSQLConnectionHandler
7986

8087
override def initChannel(ch: channel.Channel): Unit = {
8188
ch.pipeline.addLast(
82-
new MessageDecoder(configuration.charset, configuration.maximumMessageSize),
89+
new MessageDecoder(configuration.ssl.mode != Mode.Disable, configuration.charset, configuration.maximumMessageSize),
8390
new MessageEncoder(configuration.charset, encoderRegistry),
8491
PostgreSQLConnectionHandler.this)
8592
}
@@ -120,13 +127,61 @@ class PostgreSQLConnectionHandler
120127
}
121128

122129
override def channelActive(ctx: ChannelHandlerContext): Unit = {
123-
ctx.writeAndFlush(new StartupMessage(this.properties))
130+
if (configuration.ssl.mode == Mode.Disable)
131+
ctx.writeAndFlush(new StartupMessage(this.properties))
132+
else
133+
ctx.writeAndFlush(SSLRequestMessage)
124134
}
125135

126136
override def channelRead0(ctx: ChannelHandlerContext, msg: Object): Unit = {
127137

128138
msg match {
129139

140+
case SSLResponseMessage(supported) =>
141+
if (supported) {
142+
val ctxBuilder = SslContextBuilder.forClient()
143+
if (configuration.ssl.mode >= Mode.VerifyCA) {
144+
configuration.ssl.rootCert.fold {
145+
val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
146+
val ks = KeyStore.getInstance(KeyStore.getDefaultType())
147+
val cacerts = new FileInputStream(System.getProperty("java.home") + "/lib/security/cacerts")
148+
try {
149+
ks.load(cacerts, "changeit".toCharArray)
150+
} finally {
151+
cacerts.close()
152+
}
153+
tmf.init(ks)
154+
ctxBuilder.trustManager(tmf)
155+
} { path =>
156+
ctxBuilder.trustManager(path)
157+
}
158+
} else {
159+
ctxBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE)
160+
}
161+
val sslContext = ctxBuilder.build()
162+
val sslEngine = sslContext.newEngine(ctx.alloc(), configuration.host, configuration.port)
163+
if (configuration.ssl.mode >= Mode.VerifyFull) {
164+
val sslParams = sslEngine.getSSLParameters()
165+
sslParams.setEndpointIdentificationAlgorithm("HTTPS")
166+
sslEngine.setSSLParameters(sslParams)
167+
}
168+
val handler = new SslHandler(sslEngine)
169+
ctx.pipeline().addFirst(handler)
170+
handler.handshakeFuture.addListener(new FutureListener[channel.Channel]() {
171+
def operationComplete(future: io.netty.util.concurrent.Future[channel.Channel]) {
172+
if (future.isSuccess()) {
173+
ctx.writeAndFlush(new StartupMessage(properties))
174+
} else {
175+
connectionDelegate.onError(future.cause())
176+
}
177+
}
178+
})
179+
} else if (configuration.ssl.mode < Mode.Require) {
180+
ctx.writeAndFlush(new StartupMessage(properties))
181+
} else {
182+
connectionDelegate.onError(new IllegalArgumentException("SSL is not supported on server"))
183+
}
184+
130185
case m: ServerMessage => {
131186

132187
(m.kind : @switch) match {
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.github.mauricio.async.db.postgresql.encoders
2+
3+
import io.netty.buffer.ByteBuf
4+
import io.netty.buffer.Unpooled
5+
6+
object SSLMessageEncoder {
7+
8+
def encode(): ByteBuf = {
9+
val buffer = Unpooled.buffer()
10+
buffer.writeInt(8)
11+
buffer.writeShort(1234)
12+
buffer.writeShort(5679)
13+
buffer
14+
}
15+
16+
}

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/StartupMessageEncoder.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,11 @@ import com.github.mauricio.async.db.util.ByteBufferUtils
2121
import java.nio.charset.Charset
2222
import io.netty.buffer.{Unpooled, ByteBuf}
2323

24-
class StartupMessageEncoder(charset: Charset) extends Encoder {
24+
class StartupMessageEncoder(charset: Charset) {
2525

2626
//private val log = Log.getByName("StartupMessageEncoder")
2727

28-
override def encode(message: ClientMessage): ByteBuf = {
29-
30-
val startup = message.asInstanceOf[StartupMessage]
28+
def encode(startup: StartupMessage): ByteBuf = {
3129

3230
val buffer = Unpooled.buffer()
3331
buffer.writeInt(0)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package com.github.mauricio.async.db.postgresql.messages.backend
2+
3+
case class SSLResponseMessage(supported: Boolean)

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/messages/backend/ServerMessage.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ object ServerMessage {
4343
final val Query = 'Q'
4444
final val RowDescription = 'T'
4545
final val ReadyForQuery = 'Z'
46-
final val Startup = '0'
4746
final val Sync = 'S'
4847
}
4948

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package com.github.mauricio.async.db.postgresql.messages.frontend
2+
3+
trait InitialClientMessage

0 commit comments

Comments
 (0)