Skip to content

Commit 3b16878

Browse files
committed
Added callback support and as async (requires change to libsignal-protocol-dotnet for it)
1 parent db48599 commit 3b16878

File tree

1 file changed

+90
-44
lines changed

1 file changed

+90
-44
lines changed

libsignal-service-dotnet/crypto/SignalServiceCipher.cs

+90-44
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
using System;
1313
using System.Collections.Generic;
14+
using System.Threading.Tasks;
1415
using static libsignalservice.messages.SignalServiceDataMessage;
1516
using static libsignalservice.push.DataMessage;
1617

@@ -57,83 +58,128 @@ public OutgoingPushMessage Encrypt(SignalProtocolAddress destination, byte[] unp
5758
/// </summary>
5859
/// <param name="envelope">The received SignalServiceEnvelope</param>
5960
/// <returns>a decrypted SignalServiceContent</returns>
60-
public SignalServiceContent Decrypt(SignalServiceEnvelope envelope)
61+
public async Task<SignalServiceContent> Decrypt(SignalServiceEnvelope envelope, Func<SignalServiceContent, Task> callback = null)
6162
{
63+
Func<byte[], Task> callback_func = null;
64+
if (callback != null)
65+
{
66+
callback_func = async (data) => await callback(await DecryptComplete(envelope, data));
67+
}
6268
try
6369
{
64-
SignalServiceContent content = new SignalServiceContent();
65-
70+
byte[] decrypted_data = null;
6671
if (envelope.HasLegacyMessage())
6772
{
68-
DataMessage message = DataMessage.Parser.ParseFrom(Decrypt(envelope, envelope.GetLegacyMessage()));
69-
content = new SignalServiceContent()
70-
{
71-
Message = CreateSignalServiceMessage(envelope, message)
72-
};
73+
decrypted_data = await Decrypt(envelope, envelope.GetLegacyMessage(), callback_func);
7374
}
7475
else if (envelope.HasContent())
7576
{
76-
Content message = Content.Parser.ParseFrom(Decrypt(envelope, envelope.GetContent()));
77+
decrypted_data = await Decrypt(envelope, envelope.GetContent(), callback_func);
78+
}
79+
if (callback_func != null)
80+
{
81+
return null;
82+
}
83+
return await DecryptComplete(envelope, decrypted_data);
84+
}
85+
catch (InvalidProtocolBufferException e)
86+
{
87+
throw new InvalidMessageException(e);
88+
}
89+
}
90+
private async Task<SignalServiceContent> DecryptComplete(SignalServiceEnvelope envelope, byte[] decrypted_data)
91+
{
92+
SignalServiceContent content = new SignalServiceContent();
93+
94+
if (envelope.HasLegacyMessage())
95+
{
96+
DataMessage message = DataMessage.Parser.ParseFrom(decrypted_data);
97+
content = new SignalServiceContent()
98+
{
99+
Message = CreateSignalServiceMessage(envelope, message)
100+
};
101+
}
102+
else if (envelope.HasContent())
103+
{
104+
Content message = Content.Parser.ParseFrom(decrypted_data);
77105

78-
if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage)
106+
if (message.DataMessageOneofCase == Content.DataMessageOneofOneofCase.DataMessage)
107+
{
108+
content = new SignalServiceContent()
79109
{
80-
content = new SignalServiceContent()
81-
{
82-
Message = CreateSignalServiceMessage(envelope, message.DataMessage)
83-
};
84-
}
85-
else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage && LocalAddress.E164number == envelope.GetSource())
110+
Message = CreateSignalServiceMessage(envelope, message.DataMessage)
111+
};
112+
}
113+
else if (message.SyncMessageOneofCase == Content.SyncMessageOneofOneofCase.SyncMessage && LocalAddress.E164number == envelope.GetSource())
114+
{
115+
content = new SignalServiceContent()
86116
{
87-
content = new SignalServiceContent()
88-
{
89-
SynchronizeMessage = CreateSynchronizeMessage(envelope, message.SyncMessage)
90-
};
91-
}
92-
else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage)
117+
SynchronizeMessage = CreateSynchronizeMessage(envelope, message.SyncMessage)
118+
};
119+
}
120+
else if (message.CallMessageOneofCase == Content.CallMessageOneofOneofCase.CallMessage)
121+
{
122+
content = new SignalServiceContent()
93123
{
94-
content = new SignalServiceContent()
95-
{
96-
CallMessage = CreateCallMessage(message.CallMessage)
97-
};
98-
}
99-
else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage)
124+
CallMessage = CreateCallMessage(message.CallMessage)
125+
};
126+
}
127+
else if (message.ReceiptMessageOneofCase == Content.ReceiptMessageOneofOneofCase.ReceiptMessage)
128+
{
129+
content = new SignalServiceContent()
100130
{
101-
content = new SignalServiceContent()
102-
{
103-
ReadMessage = CreateReceiptMessage(envelope, message.ReceiptMessage)
104-
};
105-
}
131+
ReadMessage = CreateReceiptMessage(envelope, message.ReceiptMessage)
132+
};
106133
}
107-
108-
return content;
109134
}
110-
catch (InvalidProtocolBufferException e)
135+
136+
return content;
137+
}
138+
private class DecryptionCallbackHandler : DecryptionCallback
139+
{
140+
public Task handlePlaintext(byte[] plaintext, SessionRecord sessionRecord)
111141
{
112-
throw new InvalidMessageException(e);
142+
return callback(GetStrippedMessage(sessionCipher, plaintext));
113143
}
144+
public SessionCipher sessionCipher;
145+
public Func<byte[], Task> callback;
114146
}
115-
116-
private byte[] Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
147+
private async Task<byte[]> Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext, Func<byte[], Task> callback = null)
117148

118149
{
119150
SignalProtocolAddress sourceAddress = new SignalProtocolAddress(envelope.GetSource(), (uint)envelope.GetSourceDevice());
120151
SessionCipher sessionCipher = new SessionCipher(SignalProtocolStore, sourceAddress);
121152

122153
byte[] paddedMessage;
123-
154+
DecryptionCallbackHandler callback_handler = null;
155+
if (callback != null)
156+
callback_handler = new DecryptionCallbackHandler { callback = callback, sessionCipher = sessionCipher };
124157
if (envelope.IsPreKeySignalMessage())
125158
{
159+
if (callback_handler != null)
160+
{
161+
await sessionCipher.decrypt(new PreKeySignalMessage(ciphertext), callback_handler);
162+
return null;
163+
}
126164
paddedMessage = sessionCipher.decrypt(new PreKeySignalMessage(ciphertext));
127165
}
128166
else if (envelope.IsSignalMessage())
129167
{
168+
if (callback_handler != null)
169+
{
170+
await sessionCipher.decrypt(new SignalMessage(ciphertext), callback_handler);
171+
return null;
172+
}
130173
paddedMessage = sessionCipher.decrypt(new SignalMessage(ciphertext));
131174
}
132175
else
133176
{
134177
throw new InvalidMessageException("Unknown type: " + envelope.GetEnvelopeType() + " from " + envelope.GetSource());
135178
}
136-
179+
return GetStrippedMessage(sessionCipher, paddedMessage);
180+
}
181+
private static byte[] GetStrippedMessage(SessionCipher sessionCipher, byte[] paddedMessage)
182+
{
137183
PushTransportDetails transportDetails = new PushTransportDetails(sessionCipher.getSessionVersion());
138184
return transportDetails.GetStrippedPaddingMessageBody(paddedMessage);
139185
}
@@ -152,7 +198,7 @@ private SignalServiceDataMessage CreateSignalServiceMessage(SignalServiceEnvelop
152198
attachments.Add(CreateAttachmentPointer(envelope.GetRelay(), pointer));
153199
}
154200

155-
if (content.TimestampOneofCase == DataMessage.TimestampOneofOneofCase.Timestamp && (long) content.Timestamp != envelope.GetTimestamp())
201+
if (content.TimestampOneofCase == DataMessage.TimestampOneofOneofCase.Timestamp && (long)content.Timestamp != envelope.GetTimestamp())
156202
{
157203
throw new InvalidMessageException("Timestamps don't match: " + content.Timestamp + " vs " + envelope.GetTimestamp());
158204
}
@@ -290,7 +336,7 @@ private SignalServiceCallMessage CreateCallMessage(CallMessage content)
290336
var l = new List<IceUpdateMessage>();
291337
foreach (var u in content.IceUpdate)
292338
{
293-
l.Add(new IceUpdateMessage()
339+
l.Add(new IceUpdateMessage()
294340
{
295341
Id = u.Id,
296342
SdpMid = u.SdpMid,
@@ -374,7 +420,7 @@ private SignalServiceDataMessage.SignalServiceQuote CreateQuote(SignalServiceEnv
374420
pointer.ThumbnailOneofCase == Types.Quote.Types.QuotedAttachment.ThumbnailOneofOneofCase.Thumbnail ? CreateAttachmentPointer(envelope.GetRelay(), pointer.Thumbnail) : null));
375421
}
376422

377-
return new SignalServiceDataMessage.SignalServiceQuote((long) content.Quote.Id,
423+
return new SignalServiceDataMessage.SignalServiceQuote((long)content.Quote.Id,
378424
new SignalServiceAddress(content.Quote.Author),
379425
content.Quote.Text,
380426
attachments);

0 commit comments

Comments
 (0)