Skip to content

Commit 89eb9bd

Browse files
Add new message type for vars serialization
1 parent 6c60575 commit 89eb9bd

File tree

4 files changed

+90
-1
lines changed

4 files changed

+90
-1
lines changed

Diff for: src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt

+17-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import kotlinx.serialization.json.decodeFromJsonElement
2323
import kotlinx.serialization.json.encodeToJsonElement
2424
import kotlinx.serialization.json.jsonObject
2525
import kotlinx.serialization.serializer
26+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
2627
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
2728
import kotlin.reflect.KClass
2829
import kotlin.reflect.full.createType
@@ -86,7 +87,10 @@ enum class MessageType(val contentClass: KClass<out MessageContent>) {
8687
COMM_CLOSE(CommClose::class),
8788

8889
LIST_ERRORS_REQUEST(ListErrorsRequest::class),
89-
LIST_ERRORS_REPLY(ListErrorsReply::class);
90+
LIST_ERRORS_REPLY(ListErrorsReply::class),
91+
92+
SERIALIZATION_REQUEST(SerializationRequest::class),
93+
SERIALIZATION_REPLY(SerializationReply::class);
9094

9195
// TODO: add custom commands
9296
// this custom message should be supported on client-side. either JS or Idea Plugin
@@ -573,6 +577,18 @@ class ListErrorsReply(
573577
val errors: List<ScriptDiagnostic>
574578
) : MessageContent()
575579

580+
@Serializable
581+
class SerializationRequest(
582+
val cellId: Int,
583+
val descriptorsState: Map<String, SerializedVariablesState>
584+
) : MessageContent()
585+
586+
@Serializable
587+
class SerializationReply(
588+
val cellId: Int,
589+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
590+
) : MessageContent()
591+
576592
@Serializable(MessageDataSerializer::class)
577593
data class MessageData(
578594
val header: MessageHeader? = null,

Diff for: src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt

+7
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,13 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
319319
}
320320
}
321321
}
322+
is SerializationRequest -> {
323+
GlobalScope.launch(Dispatchers.Default) {
324+
repl.serializeVariables(content.cellId, content.descriptorsState) { result ->
325+
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
326+
}
327+
}
328+
}
322329
is IsCompleteRequest -> {
323330
// We are in console mode, so switch off all the loggers
324331
if (mainLoggerLevel() != Level.OFF) disableLogging()

Diff for: src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt

+23
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.jetbrains.kotlinx.jupyter.compiler.ScriptImportsCollector
2727
import org.jetbrains.kotlinx.jupyter.compiler.util.Classpath
2828
import org.jetbrains.kotlinx.jupyter.compiler.util.EvaluatedSnippetMetadata
2929
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedCompiledScriptsData
30+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
3031
import org.jetbrains.kotlinx.jupyter.config.catchAll
3132
import org.jetbrains.kotlinx.jupyter.config.getCompilationConfiguration
3233
import org.jetbrains.kotlinx.jupyter.dependencies.JupyterScriptDependenciesResolverImpl
@@ -122,6 +123,8 @@ interface ReplForJupyter {
122123

123124
suspend fun listErrors(code: Code, callback: (ListErrorsResult) -> Unit)
124125

126+
suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
127+
125128
val homeDir: File?
126129

127130
val currentClasspath: Collection<String>
@@ -513,6 +516,20 @@ class ReplForJupyterImpl(
513516
return ListErrorsResult(args.code, errorsList)
514517
}
515518

519+
private val serializationQueue = LockQueue<SerializationReply, SerializationArgs>()
520+
override suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit) {
521+
doWithLock(SerializationArgs(cellId, descriptorsState, callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
522+
}
523+
524+
private fun doSerializeVariables(args: SerializationArgs): SerializationReply {
525+
val resultMap = mutableMapOf<String, SerializedVariablesState>()
526+
args.descriptorsState.forEach { (name, state) ->
527+
resultMap[name] = variablesSerializer.doIncrementalSerialization(args.cellId - 1, name, state)
528+
}
529+
return SerializationReply(args.cellId, resultMap)
530+
}
531+
532+
516533
private fun <T, Args : LockQueueArgs<T>> doWithLock(
517534
args: Args,
518535
queue: LockQueue<T, Args>,
@@ -545,6 +562,12 @@ class ReplForJupyterImpl(
545562
private data class ListErrorsArgs(val code: String, override val callback: (ListErrorsResult) -> Unit) :
546563
LockQueueArgs<ListErrorsResult>
547564

565+
private data class SerializationArgs(
566+
val cellId: Int,
567+
val descriptorsState: Map<String, SerializedVariablesState>,
568+
override val callback: (SerializationReply) -> Unit
569+
) : LockQueueArgs<SerializationReply>
570+
548571
@JvmInline
549572
private value class LockQueue<T, Args : LockQueueArgs<T>>(
550573
private val args: AtomicReference<Args?> = AtomicReference()

Diff for: src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/ReplTests.kt

+43
Original file line numberDiff line numberDiff line change
@@ -911,4 +911,47 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
911911
assertEquals("${values++}", state.value)
912912
}
913913
}
914+
915+
@Test
916+
fun testSerializationMessage() {
917+
val res = eval(
918+
"""
919+
val x = listOf(listOf(1), listOf(2), listOf(3), listOf(4))
920+
""".trimIndent(),
921+
jupyterId = 1
922+
)
923+
val varsData = res.metadata.evaluatedVariablesState
924+
assertEquals(1, varsData.size)
925+
val listData = varsData["x"]!!
926+
assertTrue(listData.isContainer)
927+
val actualContainer = listData.fieldDescriptor.entries.first().value!!
928+
val propertyName = listData.fieldDescriptor.entries.first().key
929+
930+
runBlocking {
931+
repl.serializeVariables(1, mapOf(propertyName to actualContainer)) { result ->
932+
val data = result.descriptorsState
933+
assertTrue(data.isNotEmpty())
934+
935+
val innerList = data.entries.last().value!!
936+
assertTrue(innerList.isContainer)
937+
var receivedDescriptor = innerList.fieldDescriptor
938+
assertEquals(2, receivedDescriptor.size)
939+
receivedDescriptor = receivedDescriptor.entries.last().value!!.fieldDescriptor
940+
941+
assertEquals(5, receivedDescriptor.size)
942+
var values = 1
943+
receivedDescriptor.forEach { (name, state) ->
944+
if (name == "size") {
945+
assertFalse(state!!.isContainer)
946+
assertTrue(state!!.fieldDescriptor.isEmpty())
947+
return@forEach
948+
}
949+
val fieldDescriptor = state!!.fieldDescriptor
950+
assertEquals(0, fieldDescriptor.size)
951+
assertTrue(state.isContainer)
952+
assertEquals("${values++}", state.value)
953+
}
954+
}
955+
}
956+
}
914957
}

0 commit comments

Comments
 (0)