Skip to content

Commit 5aebd3d

Browse files
Make serialization contain comm_id to respect jupyter comm handlers
1 parent ec951f7 commit 5aebd3d

File tree

4 files changed

+19
-14
lines changed

4 files changed

+19
-14
lines changed

jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/compiler/util/serializedCompiledScript.kt

+3-2
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ data class SerializedVariablesState(
5252

5353
@Serializable
5454
class SerializationReply(
55-
val cellId: Int = 1,
56-
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
55+
val cell_id: Int = 1,
56+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap(),
57+
val comm_id: String = ""
5758
)
5859

5960
@Serializable

src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt

+5-3
Original file line numberDiff line numberDiff line change
@@ -561,13 +561,15 @@ class SerializationRequest(
561561
val cellId: Int,
562562
val descriptorsState: Map<String, SerializedVariablesState>,
563563
val topLevelDescriptorName: String = "",
564-
val pathToDescriptor: List<String> = emptyList()
564+
val pathToDescriptor: List<String> = emptyList(),
565+
val commId: String = ""
565566
) : MessageContent()
566567

567568
@Serializable
568569
class SerializationReply(
569-
val cellId: Int = 1,
570-
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
570+
val cell_id: Int = 1,
571+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap(),
572+
val comm_id: String = ""
571573
) : MessageContent()
572574

573575
@Serializable(MessageDataSerializer::class)

src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt

+6-4
Original file line numberDiff line numberDiff line change
@@ -308,21 +308,23 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
308308
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_INFO_REPLY, content = CommInfoReply(mapOf())))
309309
}
310310
is CommOpen -> {
311-
if (!content.commId.equals(MessageType.SERIALIZATION_REQUEST.name, ignoreCase = true)) {
311+
if (!content.targetName.equals("kotlin_serialization", ignoreCase = true)) {
312312
send(makeReplyMessage(msg, MessageType.NONE))
313313
return
314314
}
315315
log.debug("Message type in CommOpen: $msg, ${msg.type}")
316316
val data = content.data ?: return sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY))
317-
317+
if (data.isEmpty()) return sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY))
318+
log.debug("Message data: $data")
318319
val messageContent = getVariablesDescriptorsFromJson(data)
319320
GlobalScope.launch(Dispatchers.Default) {
320321
repl.serializeVariables(
321322
messageContent.topLevelDescriptorName,
322323
messageContent.descriptorsState,
324+
content.commId,
323325
messageContent.pathToDescriptor
324326
) { result ->
325-
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_OPEN, content = result))
327+
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_MSG, content = result))
326328
}
327329
}
328330
}
@@ -343,7 +345,7 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
343345
is SerializationRequest -> {
344346
GlobalScope.launch(Dispatchers.Default) {
345347
if (content.topLevelDescriptorName.isNotEmpty()) {
346-
repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState, content.pathToDescriptor) { result ->
348+
repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState, commID = content.commId, content.pathToDescriptor) { result ->
347349
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
348350
}
349351
} else {

src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt

+5-5
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ interface ReplForJupyter {
142142

143143
suspend fun serializeVariables(cellId: Int, topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
144144

145-
suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, pathToDescriptor: List<String> = emptyList(),
145+
suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, commID: String = "", pathToDescriptor: List<String> = emptyList(),
146146
callback: (SerializationReply) -> Unit)
147147

148148
val homeDir: File?
@@ -576,9 +576,8 @@ class ReplForJupyterImpl(
576576
doWithLock(SerializationArgs(descriptorsState, cellId = cellId, topLevelVarName = topLevelVarName, callback = callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
577577
}
578578

579-
override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, pathToDescriptor: List<String>,
580-
callback: (SerializationReply) -> Unit) {
581-
doWithLock(SerializationArgs(descriptorsState, topLevelVarName = topLevelVarName, callback = callback, pathToDescriptor = pathToDescriptor), serializationQueue, SerializationReply(), ::doSerializeVariables)
579+
override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, commID: String, pathToDescriptor: List<String>, callback: (SerializationReply) -> Unit) {
580+
doWithLock(SerializationArgs(descriptorsState, topLevelVarName = topLevelVarName, callback = callback, comm_id = commID ,pathToDescriptor = pathToDescriptor), serializationQueue, SerializationReply(), ::doSerializeVariables)
582581
}
583582

584583
private fun doSerializeVariables(args: SerializationArgs): SerializationReply {
@@ -593,7 +592,7 @@ class ReplForJupyterImpl(
593592
}
594593
log.debug("Serialization cellID: $cellId")
595594
log.debug("Serialization answer: ${resultMap.entries.first().value.fieldDescriptor}")
596-
return SerializationReply(cellId, resultMap)
595+
return SerializationReply(cellId, resultMap, args.comm_id)
597596
}
598597

599598

@@ -634,6 +633,7 @@ class ReplForJupyterImpl(
634633
var cellId: Int = -1,
635634
val topLevelVarName: String = "",
636635
val pathToDescriptor: List<String> = emptyList(),
636+
val comm_id: String = "",
637637
override val callback: (SerializationReply) -> Unit
638638
) : LockQueueArgs<SerializationReply>
639639

0 commit comments

Comments
 (0)