Skip to content

Commit 833b62c

Browse files
committed
Wire stuff, remove String payload
1 parent 3a7438c commit 833b62c

File tree

7 files changed

+212
-60
lines changed

7 files changed

+212
-60
lines changed

core/src/jsonrpclib/Codec.scala

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,28 @@
11
package jsonrpclib
22

33
import com.github.plokhotnyuk.jsoniter_scala.core._
4-
import jsonrpclib.Payload.BytesPayload
5-
import jsonrpclib.Payload.StringPayload
4+
import jsonrpclib.Payload
65

76
trait Codec[A] {
87

9-
def encodeBytes(a: A): Payload.BytesPayload
10-
def encodeString(a: A): Payload.StringPayload
8+
def encode(a: A): Payload
119
def decode(payload: Option[Payload]): Either[ProtocolError, A]
1210

1311
}
1412

1513
object Codec {
1614

17-
def encodeBytes[A](a: A)(implicit codec: Codec[A]): Payload.BytesPayload = codec.encodeBytes(a)
18-
def encodeString[A](a: A)(implicit codec: Codec[A]): Payload.StringPayload = codec.encodeString(a)
15+
def encode[A](a: A)(implicit codec: Codec[A]): Payload = codec.encode(a)
1916
def decode[A](payload: Option[Payload])(implicit codec: Codec[A]): Either[ProtocolError, A] = codec.decode(payload)
2017

2118
implicit def fromJsonCodec[A](implicit jsonCodec: JsonValueCodec[A]): Codec[A] = new Codec[A] {
22-
def encodeBytes(a: A): Payload.BytesPayload = Payload.BytesPayload(writeToArray(a))
23-
24-
def encodeString(a: A): Payload.StringPayload = Payload.StringPayload(writeToString(a))
19+
def encode(a: A): Payload = Payload(writeToArray(a))
2520

2621
def decode(payload: Option[Payload]): Either[ProtocolError, A] = {
2722
try {
2823
payload match {
29-
case Some(BytesPayload(array)) => Right(readFromArray(array))
30-
case Some(StringPayload(str)) => Right(readFromString(str))
31-
case None => Left(ProtocolError.ParseError("Expected to decode a payload"))
24+
case Some(Payload(array)) => Right(readFromArray(array))
25+
case None => Left(ProtocolError.ParseError("Expected to decode a payload"))
3226
}
3327
} catch { case e: JsonReaderException => Left(ProtocolError.ParseError(e.getMessage())) }
3428
}

core/src/jsonrpclib/Payload.scala

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,34 @@ import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec
55
import com.github.plokhotnyuk.jsoniter_scala.core.JsonReader
66
import com.github.plokhotnyuk.jsoniter_scala.core.JsonWriter
77

8-
sealed trait Payload
9-
object Payload {
10-
final case class StringPayload(str: String) extends Payload
11-
final case class BytesPayload(array: Array[Byte]) extends Payload {
12-
override def equals(other: Any) = other match {
13-
case bytes: BytesPayload => java.util.Arrays.equals(array, bytes.array)
14-
case _ => false
15-
}
8+
final case class Payload(array: Array[Byte]) {
9+
override def equals(other: Any) = other match {
10+
case bytes: Payload => java.util.Arrays.equals(array, bytes.array)
11+
case _ => false
12+
}
1613

17-
override def hashCode(): Int = {
18-
var hashCode = 0
19-
var i = 0
20-
while (i < array.length) {
21-
hashCode += array(i).hashCode()
22-
i += 1
23-
}
24-
hashCode
14+
override def hashCode(): Int = {
15+
var hashCode = 0
16+
var i = 0
17+
while (i < array.length) {
18+
hashCode += array(i).hashCode()
19+
i += 1
2520
}
26-
27-
override def toString = Base64.getEncoder().encodeToString(array)
21+
hashCode
2822
}
2923

24+
override def toString = Base64.getEncoder().encodeToString(array)
25+
}
26+
object Payload {
27+
3028
implicit val payloadJsonValueCodec: JsonValueCodec[Payload] = new JsonValueCodec[Payload] {
3129
def decodeValue(in: JsonReader, default: Payload): Payload = {
32-
Payload.BytesPayload(in.readRawValAsBytes())
33-
}
34-
def encodeValue(x: Payload, out: JsonWriter): Unit = x match {
35-
case StringPayload(str) => out.writeRawVal(str.getBytes())
36-
case BytesPayload(array) => out.writeRawVal(array)
30+
Payload(in.readRawValAsBytes())
3731
}
38-
def nullValue: StringPayload = null
32+
33+
def encodeValue(bytes: Payload, out: JsonWriter): Unit =
34+
out.writeRawVal(bytes.array)
35+
36+
def nullValue: Payload = null
3937
}
4038
}

core/src/jsonrpclib/internals/JsonRPCChannel.scala renamed to core/src/jsonrpclib/internals/FutureBaseChannel.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package jsonrpclib
33
import scala.concurrent.ExecutionContext
44
import scala.concurrent.Future
55
import jsonrpclib.internals._
6-
import jsonrpclib.Payload.BytesPayload
7-
import jsonrpclib.Payload.StringPayload
86
import scala.concurrent.Promise
97
import java.util.concurrent.atomic.AtomicLong
108
import jsonrpclib.Endpoint.NotificationEndpoint
@@ -28,8 +26,7 @@ abstract class FutureBasedChannel(endpoints: List[Endpoint[Future]])(implicit ec
2826
protected def getEndpoint(method: String): Future[Option[Endpoint[Future]]] =
2927
Future.successful(endpointsMap.get(method))
3028
protected def sendMessage(message: Message): Future[Unit] = {
31-
sendPayload(Codec.encodeBytes(message))
32-
Future.successful(())
29+
sendPayload(Codec.encode(message)).map(_ => ())
3330
}
3431
protected def nextCallId(): Future[CallId] = Future.successful(CallId.NumberId(nextID.incrementAndGet()))
3532

core/src/jsonrpclib/internals/MessageDispatcher.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ package internals
44
import scala.concurrent.ExecutionContext
55
import scala.concurrent.Future
66
import jsonrpclib.internals._
7-
import jsonrpclib.Payload.BytesPayload
8-
import jsonrpclib.Payload.StringPayload
97
import scala.concurrent.Promise
108
import java.util.concurrent.atomic.AtomicLong
119
import jsonrpclib.Endpoint.NotificationEndpoint
@@ -30,7 +28,7 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
3028
protected def removePendingCall(callId: CallId): F[Option[OutputMessage => F[Unit]]]
3129

3230
def notificationStub[In](method: String)(implicit inCodec: Codec[In]): In => F[Unit] = { (input: In) =>
33-
val encoded = inCodec.encodeBytes(input)
31+
val encoded = inCodec.encode(input)
3432
val message = InputMessage.NotificationMessage(method, Some(encoded))
3533
sendMessage(message)
3634
}
@@ -39,7 +37,7 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
3937
method: String
4038
)(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): In => F[Either[Err, Out]] = {
4139
(input: In) =>
42-
val encoded = inCodec.encodeBytes(input)
40+
val encoded = inCodec.encode(input)
4341
doFlatMap(nextCallId()) { callId =>
4442
val message = InputMessage.RequestMessage(method, callId, Some(encoded))
4543
doFlatMap(createPromise[Either[Err, Out]]()) { case (fulfill, future) =>
@@ -93,7 +91,7 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
9391
case Right(value) =>
9492
doFlatMap(ep.run(value)) {
9593
case Right(data) =>
96-
val responseData = ep.outCodec.encodeBytes(data)
94+
val responseData = ep.outCodec.encode(data)
9795
sendMessage(OutputMessage.ResponseMessage(callId, responseData))
9896
case Left(error) =>
9997
val errorPayload = ep.errCodec.encode(error)

fs2/src/jsonrpclib/fs2/FS2Channel.scala

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ package fs2interop
33

44
import jsonrpclib.internals.MessageDispatcher
55

6-
import fs2._
6+
import _root_.fs2.Stream
7+
import _root_.fs2.Pipe
78
import jsonrpclib.internals._
89
import scala.util.Try
910
import cats.Monad
@@ -15,12 +16,10 @@ import scala.util.Success
1516
import cats.Applicative
1617
import cats.data.Kleisli
1718
import cats.MonadThrow
18-
import jsonrpclib.StubTemplate.NotificationTemplate
19-
import jsonrpclib.StubTemplate.RequestResponseTemplate
19+
import jsonrpclib.StubTemplate._
2020
import cats.Defer
21-
import jsonrpclib.internals.OutputMessage.ErrorMessage
22-
import jsonrpclib.internals.OutputMessage.ResponseMessage
2321
import cats.Functor
22+
import jsonrpclib.internals.OutputMessage._
2423
import cats.effect.std.syntax.supervisor
2524
import cats.effect.std.Supervisor
2625

@@ -34,15 +33,26 @@ trait FS2Channel[F[_]] extends Channel[F] {
3433

3534
object FS2Channel {
3635

36+
def lspCompliant[F[_]: Concurrent](
37+
byteStream: Stream[F, Byte],
38+
byteSink: Pipe[F, Byte, Nothing],
39+
startingEndpoints: List[Endpoint[F]] = List.empty,
40+
bufferSize: Int = 512
41+
): Resource[F, FS2Channel[F]] = internals.LSP.writeSink(byteSink, bufferSize).flatMap { sink =>
42+
apply[F](internals.LSP.readStream(byteStream), sink, startingEndpoints)
43+
}
44+
3745
def apply[F[_]: Concurrent](
38-
inputStream: fs2.Stream[F, Payload],
39-
outputPipe: Payload => F[Unit]
46+
payloadStream: Stream[F, Payload],
47+
payloadSink: Payload => F[Unit],
48+
startingEndpoints: List[Endpoint[F]] = List.empty
4049
): Resource[F, FS2Channel[F]] = {
50+
val endpointsMap = startingEndpoints.map(ep => ep.method -> ep).toMap
4151
for {
4252
supervisor <- Supervisor[F]
43-
ref <- Ref[F].of(State[F](Map.empty, Map.empty, 0)).toResource
44-
impl = new Impl(outputPipe, ref, supervisor)
45-
_ <- inputStream.evalMap(impl.handleReceivedPayload).compile.drain.background
53+
ref <- Ref[F].of(State[F](Map.empty, endpointsMap, 0)).toResource
54+
impl = new Impl(payloadSink, ref, supervisor)
55+
_ <- payloadStream.evalMap(impl.handleReceivedPayload).compile.drain.background
4656
} yield impl
4757
}
4858

@@ -86,7 +96,7 @@ object FS2Channel {
8696
protected def background[A](fa: F[A]): F[Unit] = supervisor.supervise(fa).void
8797
protected def reportError(params: Option[Payload], error: ProtocolError, method: String): F[Unit] = ???
8898
protected def getEndpoint(method: String): F[Option[Endpoint[F]]] = state.get.map(_.endpoints.get(method))
89-
protected def sendMessage(message: Message): F[Unit] = sink(Codec.encodeBytes(message))
99+
protected def sendMessage(message: Message): F[Unit] = sink(Codec.encode(message))
90100
protected def nextCallId(): F[CallId] = state.modify(_.nextCallId)
91101
protected def createPromise[A](): F[(Try[A] => F[Unit], () => F[A])] = Deferred[F, Try[A]].map { promise =>
92102
def compile(trya: Try[A]): F[Unit] = promise.complete(trya).void
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package jsonrpclib.fs2interop.internals
2+
3+
import fs2.Chunk
4+
import fs2.Stream
5+
import java.nio.charset.Charset
6+
import java.nio.charset.StandardCharsets
7+
import jsonrpclib.Payload
8+
import cats.MonadThrow
9+
import cats.effect.std.Queue
10+
import cats.effect.Concurrent
11+
import cats.implicits._
12+
import cats.effect.implicits._
13+
import cats.effect.kernel.Resource
14+
15+
object LSP {
16+
17+
def writeSink[F[_]: Concurrent](
18+
writePipe: fs2.Pipe[F, Byte, Nothing],
19+
bufferSize: Int
20+
): Resource[F, Payload => F[Unit]] =
21+
Queue.bounded[F, Payload](bufferSize).toResource.flatMap { queue =>
22+
val payloads = fs2.Stream.fromQueueUnterminated(queue, bufferSize)
23+
payloads.map(writeChunk).flatMap(Stream.chunk(_)).compile.drain.background.void.as(queue.offer(_))
24+
}
25+
26+
/** Split a stream of bytes into payloads by extracting each frame based on information contained in the headers.
27+
*
28+
* See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#contentPart
29+
*/
30+
def readStream[F[_]: MonadThrow](bytes: Stream[F, Byte]): Stream[F, Payload] =
31+
bytes
32+
.scanChunks(ScanState.starting) { case (state, chunk) =>
33+
val (ns, maybeResult) = loop(state.concatChunk(chunk))
34+
(ns, Chunk(maybeResult))
35+
}
36+
.flatMap {
37+
case Right(acc) => Stream.iterable(acc).map(c => Payload(c.toArray))
38+
case Left(error) => Stream.raiseError[F](error)
39+
}
40+
41+
private def writeChunk(payload: Payload): Chunk[Byte] = {
42+
val size = payload.array.size
43+
val header = s"Content-Length: ${size}" + "\r\n" * 2
44+
Chunk.array(header.getBytes()) ++ Chunk.array(payload.array)
45+
}
46+
47+
private val returnByte = '\r'.toByte
48+
private val newlineByte = '\n'.toByte
49+
50+
private final case class LSPHeaders(
51+
contentLength: Int,
52+
mimeType: String,
53+
charset: Charset
54+
)
55+
56+
private final case class ParseError(message: String) extends Throwable {
57+
override def getMessage(): String = message
58+
}
59+
60+
private def parseHeader(
61+
line: String,
62+
headers: LSPHeaders
63+
): Either[ParseError, LSPHeaders] =
64+
line.trim() match {
65+
case s"Content-Length: ${integer(length)}" =>
66+
Right(headers.copy(contentLength = length))
67+
case s"Content-type: ${mimeType}; charset=${charset}" =>
68+
Right(
69+
headers.copy(mimeType = mimeType, charset = Charset.forName(charset))
70+
)
71+
case _ => Left(ParseError(s"Couldn't parse to header: $line"))
72+
}
73+
74+
private object integer {
75+
def unapply(string: String): Option[Int] = string.toIntOption
76+
}
77+
78+
private final case class ScanState(status: Status, currentHeaders: LSPHeaders, buffered: Chunk[Byte]) {
79+
def concatChunk(other: Chunk[Byte]) = copy(buffered = buffered ++ other)
80+
}
81+
82+
private object ScanState {
83+
def readingHeader(storedChunk: Chunk[Byte]) = ScanState(
84+
Status.ReadingHeader,
85+
LSPHeaders(-1, "application/json", StandardCharsets.UTF_8),
86+
storedChunk
87+
)
88+
89+
val starting: ScanState = readingHeader(Chunk.empty)
90+
}
91+
92+
private sealed trait Status
93+
94+
private object Status {
95+
case object ReadingHeader extends Status
96+
case object FinishedReadingHeader extends Status
97+
case object ReadingBody extends Status
98+
}
99+
100+
private def loop(
101+
state: ScanState,
102+
acc: Seq[Chunk[Byte]] = Seq.empty
103+
): (ScanState, Either[ParseError, Seq[Chunk[Byte]]]) =
104+
state match {
105+
case ScanState(Status.ReadingBody, headers, buffered) =>
106+
if (headers.contentLength <= buffered.size) {
107+
// We have a full payload to emit
108+
val (payload, tail) = buffered.splitAt(headers.contentLength)
109+
val newState = ScanState.readingHeader(tail)
110+
loop(newState, acc.appended(payload))
111+
} else {
112+
(state, Right(acc))
113+
}
114+
case ScanState(Status.ReadingHeader, headers, buffered) =>
115+
val bb = java.nio.ByteBuffer.allocate(buffered.size)
116+
val iterator = buffered.iterator
117+
var continue = true
118+
var newState: ScanState = null
119+
var error: ParseError = null
120+
while (iterator.hasNext && continue) {
121+
val byte = iterator.next
122+
if (byte == newlineByte) {
123+
parseHeader(new String(bb.array, StandardCharsets.US_ASCII), headers) match {
124+
case Right(newHeader) =>
125+
newState = ScanState(Status.FinishedReadingHeader, newHeader, Chunk.iterator(iterator))
126+
case Left(e) =>
127+
error = e
128+
}
129+
continue = false
130+
} else {
131+
bb.put(byte)
132+
}
133+
}
134+
if (newState != null) {
135+
loop(newState, acc)
136+
} else if (error != null) {
137+
(state, Left(error))
138+
} else {
139+
(state, Right(acc))
140+
}
141+
142+
case ScanState(Status.FinishedReadingHeader, headers, buffered) =>
143+
if (buffered.size >= 2) {
144+
if (buffered.startsWith(Seq(returnByte, newlineByte))) {
145+
// We have read two `\r\n` in a row, starting to scan a body
146+
loop(ScanState(Status.ReadingBody, headers, buffered.drop(2)), acc)
147+
} else {
148+
loop(ScanState(Status.ReadingHeader, headers, buffered), acc)
149+
}
150+
} else {
151+
(state, Right(acc))
152+
}
153+
}
154+
155+
}

fs2/test/src/jsonrpclib/fs2/FS2ChannelSpec.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ object FS2ChannelSpec extends SimpleIOSuite {
3434
for {
3535
stdout <- Queue.bounded[IO, Payload](10).toResource
3636
stdin <- Queue.bounded[IO, Payload](10).toResource
37-
serverSideChannel <- FS2Channel[IO](fs2.Stream.fromQueueUnterminated(stdin), payload => stdout.offer(payload))
38-
clientSideChannel <- FS2Channel[IO](fs2.Stream.fromQueueUnterminated(stdout), payload => stdin.offer(payload))
37+
serverSideChannel <- FS2Channel[IO](fs2.Stream.fromQueueUnterminated(stdin), stdout.offer)
38+
clientSideChannel <- FS2Channel[IO](fs2.Stream.fromQueueUnterminated(stdout), stdin.offer)
3939
_ <- serverSideChannel.withEndpoint(endpoint)
4040
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
4141
result <- remoteFunction(IntWrapper(1)).toResource
@@ -49,8 +49,8 @@ object FS2ChannelSpec extends SimpleIOSuite {
4949
for {
5050
stdout <- Queue.bounded[IO, Payload](10).toResource
5151
stdin <- Queue.bounded[IO, Payload](10).toResource
52-
serverSideChannel <- FS2Channel[IO](fs2.Stream.fromQueueUnterminated(stdin), payload => stdout.offer(payload))
53-
clientSideChannel <- FS2Channel[IO](fs2.Stream.fromQueueUnterminated(stdout), payload => stdin.offer(payload))
52+
serverSideChannel <- FS2Channel[IO](fs2.Stream.fromQueueUnterminated(stdin), stdout.offer)
53+
clientSideChannel <- FS2Channel[IO](fs2.Stream.fromQueueUnterminated(stdout), stdin.offer)
5454
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
5555
result <- remoteFunction(IntWrapper(1)).attempt.toResource
5656
} yield {

0 commit comments

Comments
 (0)