Skip to content

Commit d5df931

Browse files
committed
Added callback support and as async (requires change to libsignal-protocol-dotnet for it)
1 parent db48599 commit d5df931

File tree

1 file changed

+96
-44
lines changed

1 file changed

+96
-44
lines changed

libsignal-service-dotnet/crypto/SignalServiceCipher.cs

+96-44
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
using System;
1313
using System.Collections.Generic;
14+
using System.Threading.Tasks;
1415
using static libsignalservice.messages.SignalServiceDataMessage;
1516
using static libsignalservice.push.DataMessage;
1617

@@ -56,84 +57,135 @@ public OutgoingPushMessage Encrypt(SignalProtocolAddress destination, byte[] unp
5657
/// Decrypt a received <see cref="SignalServiceEnvelope"/>
5758
/// </summary>
5859
/// <param name="envelope">The received SignalServiceEnvelope</param>
60+
/// <param name="callback">Optional callback to call during the decrypt process before it is acked</param>
5961
/// <returns>a decrypted SignalServiceContent</returns>
60-
public SignalServiceContent Decrypt(SignalServiceEnvelope envelope)
62+
public async Task<SignalServiceContent> Decrypt(SignalServiceEnvelope envelope, Func<SignalServiceContent, Task> callback = null)
6163
{
64+
Func<byte[], Task> callback_func = null;
65+
if (callback != null)
66+
{
67+
callback_func = async (data) => await callback(await DecryptComplete(envelope, data));
68+
}
6269
try
6370
{
64-
SignalServiceContent content = new SignalServiceContent();
65-
71+
byte[] decrypted_data = null;
6672
if (envelope.HasLegacyMessage())
6773
{
68-
DataMessage message = DataMessage.Parser.ParseFrom(Decrypt(envelope, envelope.GetLegacyMessage()));
69-
content = new SignalServiceContent()
70-
{
71-
Message = CreateSignalServiceMessage(envelope, message)
72-
};
74+
decrypted_data = await Decrypt(envelope, envelope.GetLegacyMessage(), callback_func);
7375
}
7476
else if (envelope.HasContent())
7577
{
76-
Content message = Content.Parser.ParseFrom(Decrypt(envelope, envelope.GetContent()));
78+
decrypted_data = await Decrypt(envelope, envelope.GetContent(), callback_func);
79+
}
80+
if (callback_func != null)
81+
{
82+
return null;
83+
}
84+
return await DecryptComplete(envelope, decrypted_data);
85+
}
86+
catch (InvalidProtocolBufferException e)
87+
{
88+
throw new InvalidMessageException(e);
89+
}
90+
}
91+
private Task<SignalServiceContent> DecryptComplete(SignalServiceEnvelope envelope, byte[] decrypted_data)
92+
{
93+
SignalServiceContent content = new SignalServiceContent();
94+
95+
if (envelope.HasLegacyMessage())
96+
{
97+
DataMessage message = DataMessage.Parser.ParseFrom(decrypted_data);
98+
content = new SignalServiceContent()
99+
{
100+
Message = CreateSignalServiceMessage(envelope, message)
101+
};
102+
}
103+
else if (envelope.HasContent())
104+
{
105+
Content message = Content.Parser.ParseFrom(decrypted_data);
77106

78-
if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage)
107+
if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage)
108+
{
109+
content = new SignalServiceContent()
79110
{
80-
content = new SignalServiceContent()
81-
{
82-
Message = CreateSignalServiceMessage(envelope, message.DataMessage)
83-
};
84-
}
85-
else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage && LocalAddress.E164number == envelope.GetSource())
111+
Message = CreateSignalServiceMessage(envelope, message.DataMessage)
112+
};
113+
}
114+
else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage && LocalAddress.E164number == envelope.GetSource())
115+
{
116+
content = new SignalServiceContent()
86117
{
87-
content = new SignalServiceContent()
88-
{
89-
SynchronizeMessage = CreateSynchronizeMessage(envelope, message.SyncMessage)
90-
};
91-
}
92-
else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage)
118+
SynchronizeMessage = CreateSynchronizeMessage(envelope, message.SyncMessage)
119+
};
120+
}
121+
else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage)
122+
{
123+
content = new SignalServiceContent()
93124
{
94-
content = new SignalServiceContent()
95-
{
96-
CallMessage = CreateCallMessage(message.CallMessage)
97-
};
98-
}
99-
else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage)
125+
CallMessage = CreateCallMessage(message.CallMessage)
126+
};
127+
}
128+
else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage)
129+
{
130+
content = new SignalServiceContent()
100131
{
101-
content = new SignalServiceContent()
102-
{
103-
ReadMessage = CreateReceiptMessage(envelope, message.ReceiptMessage)
104-
};
105-
}
132+
ReadMessage = CreateReceiptMessage(envelope, message.ReceiptMessage)
133+
};
106134
}
107-
108-
return content;
109135
}
110-
catch (InvalidProtocolBufferException e)
136+
137+
return Task.FromResult(content);
138+
}
139+
private class DecryptionCallbackHandler : DecryptionCallback
140+
{
141+
public Task handlePlaintext(byte[] plaintext, SessionRecord sessionRecord)
111142
{
112-
throw new InvalidMessageException(e);
143+
return callback(GetStrippedMessage(sessionRecord, plaintext));
113144
}
145+
public SessionCipher sessionCipher;
146+
public Func<byte[], Task> callback;
114147
}
115-
116-
private byte[] Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
148+
private async Task<byte[]> Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext, Func<byte[], Task> callback = null)
117149

118150
{
119151
SignalProtocolAddress sourceAddress = new SignalProtocolAddress(envelope.GetSource(), (uint)envelope.GetSourceDevice());
120152
SessionCipher sessionCipher = new SessionCipher(SignalProtocolStore, sourceAddress);
121153

122154
byte[] paddedMessage;
123-
155+
DecryptionCallbackHandler callback_handler = null;
156+
if (callback != null)
157+
callback_handler = new DecryptionCallbackHandler { callback = callback, sessionCipher = sessionCipher };
124158
if (envelope.IsPreKeySignalMessage())
125159
{
160+
if (callback_handler != null)
161+
{
162+
await sessionCipher.decrypt(new PreKeySignalMessage(ciphertext), callback_handler);
163+
return null;
164+
}
126165
paddedMessage = sessionCipher.decrypt(new PreKeySignalMessage(ciphertext));
127166
}
128167
else if (envelope.IsSignalMessage())
129168
{
169+
if (callback_handler != null)
170+
{
171+
await sessionCipher.decrypt(new SignalMessage(ciphertext), callback_handler);
172+
return null;
173+
}
130174
paddedMessage = sessionCipher.decrypt(new SignalMessage(ciphertext));
131175
}
132176
else
133177
{
134178
throw new InvalidMessageException("Unknown type: " + envelope.GetEnvelopeType() + " from " + envelope.GetSource());
135179
}
136-
180+
return GetStrippedMessage(sessionCipher, paddedMessage);
181+
}
182+
private static byte[] GetStrippedMessage(SessionRecord sessionRecord, byte[] paddedMessage)
183+
{
184+
PushTransportDetails transportDetails = new PushTransportDetails(sessionRecord.getSessionState().getSessionVersion());
185+
return transportDetails.GetStrippedPaddingMessageBody(paddedMessage);
186+
}
187+
private static byte[] GetStrippedMessage(SessionCipher sessionCipher, byte[] paddedMessage)
188+
{
137189
PushTransportDetails transportDetails = new PushTransportDetails(sessionCipher.getSessionVersion());
138190
return transportDetails.GetStrippedPaddingMessageBody(paddedMessage);
139191
}
@@ -152,7 +204,7 @@ private SignalServiceDataMessage CreateSignalServiceMessage(SignalServiceEnvelop
152204
attachments.Add(CreateAttachmentPointer(envelope.GetRelay(), pointer));
153205
}
154206

155-
if (content.TimestampOneofCase == DataMessage.TimestampOneofOneofCase.Timestamp && (long) content.Timestamp != envelope.GetTimestamp())
207+
if (content.TimestampOneofCase == DataMessage.TimestampOneofOneofCase.Timestamp && (long)content.Timestamp != envelope.GetTimestamp())
156208
{
157209
throw new InvalidMessageException("Timestamps don't match: " + content.Timestamp + " vs " + envelope.GetTimestamp());
158210
}
@@ -290,7 +342,7 @@ private SignalServiceCallMessage CreateCallMessage(CallMessage content)
290342
var l = new List<IceUpdateMessage>();
291343
foreach (var u in content.IceUpdate)
292344
{
293-
l.Add(new IceUpdateMessage()
345+
l.Add(new IceUpdateMessage()
294346
{
295347
Id = u.Id,
296348
SdpMid = u.SdpMid,
@@ -374,7 +426,7 @@ private SignalServiceDataMessage.SignalServiceQuote CreateQuote(SignalServiceEnv
374426
pointer.ThumbnailOneofCase == Types.Quote.Types.QuotedAttachment.ThumbnailOneofOneofCase.Thumbnail ? CreateAttachmentPointer(envelope.GetRelay(), pointer.Thumbnail) : null));
375427
}
376428

377-
return new SignalServiceDataMessage.SignalServiceQuote((long) content.Quote.Id,
429+
return new SignalServiceDataMessage.SignalServiceQuote((long)content.Quote.Id,
378430
new SignalServiceAddress(content.Quote.Author),
379431
content.Quote.Text,
380432
attachments);

0 commit comments

Comments
 (0)