11
11
12
12
using System ;
13
13
using System . Collections . Generic ;
14
+ using System . Threading . Tasks ;
14
15
using static libsignalservice . messages . SignalServiceDataMessage ;
15
16
using static libsignalservice . push . DataMessage ;
16
17
@@ -56,84 +57,135 @@ public OutgoingPushMessage Encrypt(SignalProtocolAddress destination, byte[] unp
56
57
/// Decrypt a received <see cref="SignalServiceEnvelope"/>
57
58
/// </summary>
58
59
/// <param name="envelope">The received SignalServiceEnvelope</param>
60
+ /// <param name="callback">Optional callback to call during the decrypt process before it is acked</param>
59
61
/// <returns>a decrypted SignalServiceContent</returns>
60
- public SignalServiceContent Decrypt ( SignalServiceEnvelope envelope )
62
+ public async Task < SignalServiceContent > Decrypt ( SignalServiceEnvelope envelope , Func < SignalServiceContent , Task > callback = null )
61
63
{
64
+ Func < byte [ ] , Task > callback_func = null ;
65
+ if ( callback != null )
66
+ {
67
+ callback_func = async ( data ) => await callback ( await DecryptComplete ( envelope , data ) ) ;
68
+ }
62
69
try
63
70
{
64
- SignalServiceContent content = new SignalServiceContent ( ) ;
65
-
71
+ byte [ ] decrypted_data = null ;
66
72
if ( envelope . HasLegacyMessage ( ) )
67
73
{
68
- DataMessage message = DataMessage . Parser . ParseFrom ( Decrypt ( envelope , envelope . GetLegacyMessage ( ) ) ) ;
69
- content = new SignalServiceContent ( )
70
- {
71
- Message = CreateSignalServiceMessage ( envelope , message )
72
- } ;
74
+ decrypted_data = await Decrypt ( envelope , envelope . GetLegacyMessage ( ) , callback_func ) ;
73
75
}
74
76
else if ( envelope . HasContent ( ) )
75
77
{
76
- Content message = Content . Parser . ParseFrom ( Decrypt ( envelope , envelope . GetContent ( ) ) ) ;
78
+ decrypted_data = await Decrypt ( envelope , envelope . GetContent ( ) , callback_func ) ;
79
+ }
80
+ if ( callback_func != null )
81
+ {
82
+ return null ;
83
+ }
84
+ return await DecryptComplete ( envelope , decrypted_data ) ;
85
+ }
86
+ catch ( InvalidProtocolBufferException e )
87
+ {
88
+ throw new InvalidMessageException ( e ) ;
89
+ }
90
+ }
91
+ private Task < SignalServiceContent > DecryptComplete ( SignalServiceEnvelope envelope , byte [ ] decrypted_data )
92
+ {
93
+ SignalServiceContent content = new SignalServiceContent ( ) ;
94
+
95
+ if ( envelope . HasLegacyMessage ( ) )
96
+ {
97
+ DataMessage message = DataMessage . Parser . ParseFrom ( decrypted_data ) ;
98
+ content = new SignalServiceContent ( )
99
+ {
100
+ Message = CreateSignalServiceMessage ( envelope , message )
101
+ } ;
102
+ }
103
+ else if ( envelope . HasContent ( ) )
104
+ {
105
+ Content message = Content . Parser . ParseFrom ( decrypted_data ) ;
77
106
78
- if ( message . DataMessageOneofCase == Content . DataMessageOneofOneofCase . DataMessage )
107
+ if ( message . DataMessageOneofCase == Content . DataMessageOneofOneofCase . DataMessage )
108
+ {
109
+ content = new SignalServiceContent ( )
79
110
{
80
- content = new SignalServiceContent ( )
81
- {
82
- Message = CreateSignalServiceMessage ( envelope , message . DataMessage )
83
- } ;
84
- }
85
- else if ( message . SyncMessageOneofCase == Content . SyncMessageOneofOneofCase . SyncMessage && LocalAddress . E164number == envelope . GetSource ( ) )
111
+ Message = CreateSignalServiceMessage ( envelope , message . DataMessage )
112
+ } ;
113
+ }
114
+ else if ( message . SyncMessageOneofCase == Content . SyncMessageOneofOneofCase . SyncMessage && LocalAddress . E164number == envelope . GetSource ( ) )
115
+ {
116
+ content = new SignalServiceContent ( )
86
117
{
87
- content = new SignalServiceContent ( )
88
- {
89
- SynchronizeMessage = CreateSynchronizeMessage ( envelope , message . SyncMessage )
90
- } ;
91
- }
92
- else if ( message . CallMessageOneofCase == Content . CallMessageOneofOneofCase . CallMessage )
118
+ SynchronizeMessage = CreateSynchronizeMessage ( envelope , message . SyncMessage )
119
+ } ;
120
+ }
121
+ else if ( message . CallMessageOneofCase == Content . CallMessageOneofOneofCase . CallMessage )
122
+ {
123
+ content = new SignalServiceContent ( )
93
124
{
94
- content = new SignalServiceContent ( )
95
- {
96
- CallMessage = CreateCallMessage ( message . CallMessage )
97
- } ;
98
- }
99
- else if ( message . ReceiptMessageOneofCase == Content . ReceiptMessageOneofOneofCase . ReceiptMessage )
125
+ CallMessage = CreateCallMessage ( message . CallMessage )
126
+ } ;
127
+ }
128
+ else if ( message . ReceiptMessageOneofCase == Content . ReceiptMessageOneofOneofCase . ReceiptMessage )
129
+ {
130
+ content = new SignalServiceContent ( )
100
131
{
101
- content = new SignalServiceContent ( )
102
- {
103
- ReadMessage = CreateReceiptMessage ( envelope , message . ReceiptMessage )
104
- } ;
105
- }
132
+ ReadMessage = CreateReceiptMessage ( envelope , message . ReceiptMessage )
133
+ } ;
106
134
}
107
-
108
- return content ;
109
135
}
110
- catch ( InvalidProtocolBufferException e )
136
+
137
+ return Task . FromResult ( content ) ;
138
+ }
139
+ private class DecryptionCallbackHandler : DecryptionCallback
140
+ {
141
+ public Task handlePlaintext ( byte [ ] plaintext , SessionRecord sessionRecord )
111
142
{
112
- throw new InvalidMessageException ( e ) ;
143
+ return callback ( GetStrippedMessage ( sessionRecord , plaintext ) ) ;
113
144
}
145
+ public SessionCipher sessionCipher ;
146
+ public Func < byte [ ] , Task > callback ;
114
147
}
115
-
116
- private byte [ ] Decrypt ( SignalServiceEnvelope envelope , byte [ ] ciphertext )
148
+ private async Task < byte [ ] > Decrypt ( SignalServiceEnvelope envelope , byte [ ] ciphertext , Func < byte [ ] , Task > callback = null )
117
149
118
150
{
119
151
SignalProtocolAddress sourceAddress = new SignalProtocolAddress ( envelope . GetSource ( ) , ( uint ) envelope . GetSourceDevice ( ) ) ;
120
152
SessionCipher sessionCipher = new SessionCipher ( SignalProtocolStore , sourceAddress ) ;
121
153
122
154
byte [ ] paddedMessage ;
123
-
155
+ DecryptionCallbackHandler callback_handler = null ;
156
+ if ( callback != null )
157
+ callback_handler = new DecryptionCallbackHandler { callback = callback , sessionCipher = sessionCipher } ;
124
158
if ( envelope . IsPreKeySignalMessage ( ) )
125
159
{
160
+ if ( callback_handler != null )
161
+ {
162
+ await sessionCipher . decrypt ( new PreKeySignalMessage ( ciphertext ) , callback_handler ) ;
163
+ return null ;
164
+ }
126
165
paddedMessage = sessionCipher . decrypt ( new PreKeySignalMessage ( ciphertext ) ) ;
127
166
}
128
167
else if ( envelope . IsSignalMessage ( ) )
129
168
{
169
+ if ( callback_handler != null )
170
+ {
171
+ await sessionCipher . decrypt ( new SignalMessage ( ciphertext ) , callback_handler ) ;
172
+ return null ;
173
+ }
130
174
paddedMessage = sessionCipher . decrypt ( new SignalMessage ( ciphertext ) ) ;
131
175
}
132
176
else
133
177
{
134
178
throw new InvalidMessageException ( "Unknown type: " + envelope . GetEnvelopeType ( ) + " from " + envelope . GetSource ( ) ) ;
135
179
}
136
-
180
+ return GetStrippedMessage ( sessionCipher , paddedMessage ) ;
181
+ }
182
+ private static byte [ ] GetStrippedMessage ( SessionRecord sessionRecord , byte [ ] paddedMessage )
183
+ {
184
+ PushTransportDetails transportDetails = new PushTransportDetails ( sessionRecord . getSessionState ( ) . getSessionVersion ( ) ) ;
185
+ return transportDetails . GetStrippedPaddingMessageBody ( paddedMessage ) ;
186
+ }
187
+ private static byte [ ] GetStrippedMessage ( SessionCipher sessionCipher , byte [ ] paddedMessage )
188
+ {
137
189
PushTransportDetails transportDetails = new PushTransportDetails ( sessionCipher . getSessionVersion ( ) ) ;
138
190
return transportDetails . GetStrippedPaddingMessageBody ( paddedMessage ) ;
139
191
}
@@ -152,7 +204,7 @@ private SignalServiceDataMessage CreateSignalServiceMessage(SignalServiceEnvelop
152
204
attachments . Add ( CreateAttachmentPointer ( envelope . GetRelay ( ) , pointer ) ) ;
153
205
}
154
206
155
- if ( content . TimestampOneofCase == DataMessage . TimestampOneofOneofCase . Timestamp && ( long ) content . Timestamp != envelope . GetTimestamp ( ) )
207
+ if ( content . TimestampOneofCase == DataMessage . TimestampOneofOneofCase . Timestamp && ( long ) content . Timestamp != envelope . GetTimestamp ( ) )
156
208
{
157
209
throw new InvalidMessageException ( "Timestamps don't match: " + content . Timestamp + " vs " + envelope . GetTimestamp ( ) ) ;
158
210
}
@@ -290,7 +342,7 @@ private SignalServiceCallMessage CreateCallMessage(CallMessage content)
290
342
var l = new List < IceUpdateMessage > ( ) ;
291
343
foreach ( var u in content . IceUpdate )
292
344
{
293
- l . Add ( new IceUpdateMessage ( )
345
+ l . Add ( new IceUpdateMessage ( )
294
346
{
295
347
Id = u . Id ,
296
348
SdpMid = u . SdpMid ,
@@ -374,7 +426,7 @@ private SignalServiceDataMessage.SignalServiceQuote CreateQuote(SignalServiceEnv
374
426
pointer . ThumbnailOneofCase == Types . Quote . Types . QuotedAttachment . ThumbnailOneofOneofCase . Thumbnail ? CreateAttachmentPointer ( envelope . GetRelay ( ) , pointer . Thumbnail ) : null ) ) ;
375
427
}
376
428
377
- return new SignalServiceDataMessage . SignalServiceQuote ( ( long ) content . Quote . Id ,
429
+ return new SignalServiceDataMessage . SignalServiceQuote ( ( long ) content . Quote . Id ,
378
430
new SignalServiceAddress ( content . Quote . Author ) ,
379
431
content . Quote . Text ,
380
432
attachments ) ;
0 commit comments