Skip to content

Commit f8d6b6b

Browse files
SChernykhzone117x
authored andcommitted
Cryptonight variant 2 support + tests (zone117x#64)
Reference code: monero-project/monero#4218
1 parent 18d2cbc commit f8d6b6b

File tree

3 files changed

+337
-21
lines changed

3 files changed

+337
-21
lines changed

crypto/variant2_int_sqrt.h

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
#ifndef VARIANT2_INT_SQRT_H
2+
#define VARIANT2_INT_SQRT_H
3+
4+
#include <math.h>
5+
#include <float.h>
6+
7+
#define VARIANT2_INTEGER_MATH_SQRT_STEP_SSE2() \
8+
do { \
9+
const __m128i exp_double_bias = _mm_set_epi64x(0, 1023ULL << 52); \
10+
__m128d x = _mm_castsi128_pd(_mm_add_epi64(_mm_cvtsi64_si128(sqrt_input >> 12), exp_double_bias)); \
11+
x = _mm_sqrt_sd(_mm_setzero_pd(), x); \
12+
sqrt_result = (uint64_t)(_mm_cvtsi128_si64(_mm_sub_epi64(_mm_castpd_si128(x), exp_double_bias))) >> 19; \
13+
} while(0)
14+
15+
#define VARIANT2_INTEGER_MATH_SQRT_STEP_FP64() \
16+
do { \
17+
sqrt_result = sqrt(sqrt_input + 18446744073709551616.0) * 2.0 - 8589934592.0; \
18+
} while(0)
19+
20+
#define VARIANT2_INTEGER_MATH_SQRT_STEP_REF() \
21+
sqrt_result = integer_square_root_v2(sqrt_input)
22+
23+
// Reference implementation of the integer square root for Cryptonight variant 2
24+
// Computes integer part of "sqrt(2^64 + n) * 2 - 2^33"
25+
//
26+
// In other words, given 64-bit unsigned integer n:
27+
// 1) Write it as x = 1.NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN000... in binary (1 <= x < 2, all 64 bits of n are used)
28+
// 2) Calculate sqrt(x) = 1.0RRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRR... (1 <= sqrt(x) < sqrt(2), so it will always start with "1.0" in binary)
29+
// 3) Take 32 bits that come after "1.0" and return them as a 32-bit unsigned integer, discard all remaining bits
30+
//
31+
// Some sample inputs and outputs:
32+
//
33+
// Input | Output | Exact value of "sqrt(2^64 + n) * 2 - 2^33"
34+
// -----------------|------------|-------------------------------------------
35+
// 0 | 0 | 0
36+
// 2^32 | 0 | 0.99999999994179233909330885695244...
37+
// 2^32 + 1 | 1 | 1.0000000001746229827200734316305...
38+
// 2^50 | 262140 | 262140.00012206565608606978175873...
39+
// 2^55 + 20963331 | 8384515 | 8384515.9999999997673963974959744...
40+
// 2^55 + 20963332 | 8384516 | 8384516
41+
// 2^62 + 26599786 | 1013904242 | 1013904242.9999999999479374853545...
42+
// 2^62 + 26599787 | 1013904243 | 1013904243.0000000001561875439364...
43+
// 2^64 - 1 | 3558067407 | 3558067407.9041987696409179931096...
44+
45+
// The reference implementation as it is now uses only unsigned int64 arithmetic, so it can't have undefined behavior
46+
// It was tested once for all edge cases and confirmed correct
47+
//
48+
// !!! Note: if you're modifying this code, uncomment the test in monero/tests/hash/main.cpp !!!
49+
//
50+
static inline uint32_t integer_square_root_v2(uint64_t n)
51+
{
52+
uint64_t r = 1ULL << 63;
53+
54+
for (uint64_t bit = 1ULL << 60; bit; bit >>= 2)
55+
{
56+
const bool b = (n < r + bit);
57+
const uint64_t n_next = n - (r + bit);
58+
const uint64_t r_next = r + bit * 2;
59+
n = b ? n : n_next;
60+
r = b ? r : r_next;
61+
r >>= 1;
62+
}
63+
64+
return r * 2 + ((n > r) ? 1 : 0);
65+
}
66+
67+
/*
68+
VARIANT2_INTEGER_MATH_SQRT_FIXUP checks that "r" is an integer part of "sqrt(2^64 + sqrt_input) * 2 - 2^33" and adds or subtracts 1 if needed
69+
It's hard to understand how it works, so here is a full calculation of formulas used in VARIANT2_INTEGER_MATH_SQRT_FIXUP
70+
71+
The following inequalities must hold for r if it's an integer part of "sqrt(2^64 + sqrt_input) * 2 - 2^33":
72+
1) r <= sqrt(2^64 + sqrt_input) * 2 - 2^33
73+
2) r + 1 > sqrt(2^64 + sqrt_input) * 2 - 2^33
74+
75+
We need to check them using only unsigned integer arithmetic to avoid rounding errors and undefined behavior
76+
77+
First inequality: r <= sqrt(2^64 + sqrt_input) * 2 - 2^33
78+
-----------------------------------------------------------------------------------
79+
r <= sqrt(2^64 + sqrt_input) * 2 - 2^33
80+
r + 2^33 <= sqrt(2^64 + sqrt_input) * 2
81+
r/2 + 2^32 <= sqrt(2^64 + sqrt_input)
82+
(r/2 + 2^32)^2 <= 2^64 + sqrt_input
83+
84+
Rewrite r as r = s * 2 + b (s = trunc(r/2), b is 0 or 1)
85+
86+
((s*2+b)/2 + 2^32)^2 <= 2^64 + sqrt_input
87+
(s*2+b)^2/4 + 2*2^32*(s*2+b)/2 + 2^64 <= 2^64 + sqrt_input
88+
(s*2+b)^2/4 + 2*2^32*(s*2+b)/2 <= sqrt_input
89+
(s*2+b)^2/4 + 2^32*r <= sqrt_input
90+
(s^2*4+2*s*2*b+b^2)/4 + 2^32*r <= sqrt_input
91+
s^2+s*b+b^2/4 + 2^32*r <= sqrt_input
92+
s*(s+b) + b^2/4 + 2^32*r <= sqrt_input
93+
94+
Let r2 = s*(s+b) + r*2^32
95+
r2 + b^2/4 <= sqrt_input
96+
97+
If this inequality doesn't hold, then we must decrement r: IF "r2 + b^2/4 > sqrt_input" THEN r = r - 1
98+
99+
b can be 0 or 1
100+
If b is 0 then we need to compare "r2 > sqrt_input"
101+
If b is 1 then b^2/4 = 0.25, so we need to compare "r2 + 0.25 > sqrt_input"
102+
Since both r2 and sqrt_input are integers, we can safely replace it with "r2 + 1 > sqrt_input"
103+
-----------------------------------------------------------------------------------
104+
Both cases can be merged to a single expression "r2 + b > sqrt_input"
105+
-----------------------------------------------------------------------------------
106+
There will be no overflow when calculating "r2 + b", so it's safe to compare with sqrt_input:
107+
r2 + b = s*(s+b) + r*2^32 + b
108+
The largest value s, b and r can have is s = 1779033703, b = 1, r = 3558067407 when sqrt_input = 2^64 - 1
109+
r2 + b <= 1779033703*1779033704 + 3558067407*2^32 + 1 = 18446744068217447385 < 2^64
110+
111+
Second inequality: r + 1 > sqrt(2^64 + sqrt_input) * 2 - 2^33
112+
-----------------------------------------------------------------------------------
113+
r + 1 > sqrt(2^64 + sqrt_input) * 2 - 2^33
114+
r + 1 + 2^33 > sqrt(2^64 + sqrt_input) * 2
115+
((r+1)/2 + 2^32)^2 > 2^64 + sqrt_input
116+
117+
Rewrite r as r = s * 2 + b (s = trunc(r/2), b is 0 or 1)
118+
119+
((s*2+b+1)/2 + 2^32)^2 > 2^64 + sqrt_input
120+
(s*2+b+1)^2/4 + 2*(s*2+b+1)/2*2^32 + 2^64 > 2^64 + sqrt_input
121+
(s*2+b+1)^2/4 + (s*2+b+1)*2^32 > sqrt_input
122+
(s*2+b+1)^2/4 + (r+1)*2^32 > sqrt_input
123+
(s*2+(b+1))^2/4 + r*2^32 + 2^32 > sqrt_input
124+
(s^2*4+2*s*2*(b+1)+(b+1)^2)/4 + r*2^32 + 2^32 > sqrt_input
125+
s^2+s*(b+1)+(b+1)^2/4 + r*2^32 + 2^32 > sqrt_input
126+
s*(s+b) + s + (b+1)^2/4 + r*2^32 + 2^32 > sqrt_input
127+
128+
Let r2 = s*(s+b) + r*2^32
129+
130+
r2 + s + (b+1)^2/4 + 2^32 > sqrt_input
131+
r2 + 2^32 + (b+1)^2/4 > sqrt_input - s
132+
133+
If this inequality doesn't hold, then we must decrement r: IF "r2 + 2^32 + (b+1)^2/4 <= sqrt_input - s" THEN r = r - 1
134+
b can be 0 or 1
135+
If b is 0 then we need to compare "r2 + 2^32 + 1/4 <= sqrt_input - s" which is equal to "r2 + 2^32 < sqrt_input - s" because all numbers here are integers
136+
If b is 1 then (b+1)^2/4 = 1, so we need to compare "r2 + 2^32 + 1 <= sqrt_input - s" which is also equal to "r2 + 2^32 < sqrt_input - s"
137+
-----------------------------------------------------------------------------------
138+
Both cases can be merged to a single expression "r2 + 2^32 < sqrt_input - s"
139+
-----------------------------------------------------------------------------------
140+
There will be no overflow when calculating "r2 + 2^32":
141+
r2 + 2^32 = s*(s+b) + r*2^32 + 2^32 = s*(s+b) + (r+1)*2^32
142+
The largest value s, b and r can have is s = 1779033703, b = 1, r = 3558067407 when sqrt_input = 2^64 - 1
143+
r2 + b <= 1779033703*1779033704 + 3558067408*2^32 = 18446744072512414680 < 2^64
144+
145+
There will be no integer overflow when calculating "sqrt_input - s", i.e. "sqrt_input >= s" at all times:
146+
s = trunc(r/2) = trunc(sqrt(2^64 + sqrt_input) - 2^32) < sqrt(2^64 + sqrt_input) - 2^32 + 1
147+
sqrt_input > sqrt(2^64 + sqrt_input) - 2^32 + 1
148+
sqrt_input + 2^32 - 1 > sqrt(2^64 + sqrt_input)
149+
(sqrt_input + 2^32 - 1)^2 > sqrt_input + 2^64
150+
sqrt_input^2 + 2*sqrt_input*(2^32 - 1) + (2^32-1)^2 > sqrt_input + 2^64
151+
sqrt_input^2 + sqrt_input*(2^33 - 2) + (2^32-1)^2 > sqrt_input + 2^64
152+
sqrt_input^2 + sqrt_input*(2^33 - 3) + (2^32-1)^2 > 2^64
153+
sqrt_input^2 + sqrt_input*(2^33 - 3) + 2^64-2^33+1 > 2^64
154+
sqrt_input^2 + sqrt_input*(2^33 - 3) - 2^33 + 1 > 0
155+
This inequality is true if sqrt_input > 1 and it's easy to check that s = 0 if sqrt_input is 0 or 1, so there will be no integer overflow
156+
*/
157+
158+
#define VARIANT2_INTEGER_MATH_SQRT_FIXUP(r) \
159+
do { \
160+
const uint64_t s = r >> 1; \
161+
const uint64_t b = r & 1; \
162+
const uint64_t r2 = (uint64_t)(s) * (s + b) + (r << 32); \
163+
r += ((r2 + b > sqrt_input) ? -1 : 0) + ((r2 + (1ULL << 32) < sqrt_input - s) ? 1 : 0); \
164+
} while(0)
165+
166+
#endif

cryptonight.c

+116-21
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "crypto/c_skein.h"
1515
#include "crypto/int-util.h"
1616
#include "crypto/hash-ops.h"
17+
#include "crypto/variant2_int_sqrt.h"
1718

1819
#define MEMORY (1 << 21) /* 2 MiB */
1920
#define ITER (1 << 20)
@@ -23,7 +24,7 @@
2324
#define INIT_SIZE_BYTE (INIT_SIZE_BLK * AES_BLOCK_SIZE)
2425

2526
#define VARIANT1_1(p) \
26-
do if (variant > 0) \
27+
do if (variant == 1) \
2728
{ \
2829
const uint8_t tmp = ((const uint8_t*)(p))[11]; \
2930
static const uint32_t table = 0x75310; \
@@ -32,18 +33,98 @@
3233
} while(0)
3334

3435
#define VARIANT1_2(p) \
35-
do if (variant > 0) \
36+
do if (variant == 1) \
3637
{ \
3738
((uint64_t*)p)[1] ^= tweak1_2; \
3839
} while(0)
3940

4041
#define VARIANT1_INIT() \
41-
if (variant > 0 && len < 43) \
42+
if (variant == 1 && len < 43) \
4243
{ \
43-
fprintf(stderr, "Cryptonight variants need at least 43 bytes of data"); \
44+
fprintf(stderr, "Cryptonight variant 1 needs at least 43 bytes of data"); \
4445
_exit(1); \
4546
} \
46-
const uint64_t tweak1_2 = variant > 0 ? *(const uint64_t*)(((const uint8_t*)input)+35) ^ ctx->state.hs.w[24] : 0
47+
const uint64_t tweak1_2 = (variant == 1) ? *(const uint64_t*)(((const uint8_t*)input)+35) ^ ctx->state.hs.w[24] : 0
48+
49+
#define U64(p) ((uint64_t*)(p))
50+
51+
#define VARIANT2_INIT(b, state) \
52+
uint64_t division_result; \
53+
uint64_t sqrt_result; \
54+
do if (variant >= 2) \
55+
{ \
56+
U64(b)[2] = state.hs.w[8] ^ state.hs.w[10]; \
57+
U64(b)[3] = state.hs.w[9] ^ state.hs.w[11]; \
58+
division_result = state.hs.w[12]; \
59+
sqrt_result = state.hs.w[13]; \
60+
} while (0)
61+
62+
#define VARIANT2_SHUFFLE_ADD(base_ptr, offset, a, b) \
63+
do if (variant >= 2) \
64+
{ \
65+
uint64_t* chunk1 = U64((base_ptr) + ((offset) ^ 0x10)); \
66+
uint64_t* chunk2 = U64((base_ptr) + ((offset) ^ 0x20)); \
67+
uint64_t* chunk3 = U64((base_ptr) + ((offset) ^ 0x30)); \
68+
\
69+
const uint64_t chunk1_old[2] = { chunk1[0], chunk1[1] }; \
70+
\
71+
chunk1[0] = chunk3[0] + U64(b + 16)[0]; \
72+
chunk1[1] = chunk3[1] + U64(b + 16)[1]; \
73+
\
74+
chunk3[0] = chunk2[0] + U64(a)[0]; \
75+
chunk3[1] = chunk2[1] + U64(a)[1]; \
76+
\
77+
chunk2[0] = chunk1_old[0] + U64(b)[0]; \
78+
chunk2[1] = chunk1_old[1] + U64(b)[1]; \
79+
} while (0)
80+
81+
#define VARIANT2_INTEGER_MATH_DIVISION_STEP(b, ptr) \
82+
((uint64_t*)(b))[0] ^= division_result ^ (sqrt_result << 32); \
83+
{ \
84+
const uint64_t dividend = ((uint64_t*)(ptr))[1]; \
85+
const uint32_t divisor = (((uint32_t*)(ptr))[0] + (uint32_t)(sqrt_result << 1)) | 0x80000001UL; \
86+
division_result = ((uint32_t)(dividend / divisor)) + \
87+
(((uint64_t)(dividend % divisor)) << 32); \
88+
} \
89+
const uint64_t sqrt_input = ((uint64_t*)(ptr))[0] + division_result
90+
91+
#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_WIN64))
92+
#include <emmintrin.h>
93+
94+
#if defined(_MSC_VER) || defined(__MINGW32__)
95+
#include <intrin.h>
96+
#else
97+
#include <wmmintrin.h>
98+
#endif
99+
100+
#define VARIANT2_INTEGER_MATH(b, ptr) \
101+
do if (variant >= 2) \
102+
{ \
103+
VARIANT2_INTEGER_MATH_DIVISION_STEP(b, ptr); \
104+
VARIANT2_INTEGER_MATH_SQRT_STEP_SSE2(); \
105+
VARIANT2_INTEGER_MATH_SQRT_FIXUP(sqrt_result); \
106+
} while (0)
107+
#else
108+
#if defined DBL_MANT_DIG && (DBL_MANT_DIG >= 50)
109+
// double precision floating point type has enough bits of precision on current platform
110+
#define VARIANT2_INTEGER_MATH(b, ptr) \
111+
do if (variant >= 2) \
112+
{ \
113+
VARIANT2_INTEGER_MATH_DIVISION_STEP(b, ptr); \
114+
VARIANT2_INTEGER_MATH_SQRT_STEP_FP64(); \
115+
VARIANT2_INTEGER_MATH_SQRT_FIXUP(sqrt_result); \
116+
} while (0)
117+
#else
118+
// double precision floating point type is not good enough on current platform
119+
// fall back to the reference code (integer only)
120+
#define VARIANT2_INTEGER_MATH(b, ptr) \
121+
do if (variant >= 2) \
122+
{ \
123+
VARIANT2_INTEGER_MATH_DIVISION_STEP(b, ptr); \
124+
VARIANT2_INTEGER_MATH_SQRT_STEP_REF(); \
125+
} while (0)
126+
#endif
127+
#endif
47128

48129
#pragma pack(push, 1)
49130
union cn_slow_hash_state {
@@ -88,16 +169,6 @@ static void mul(const uint8_t* a, const uint8_t* b, uint8_t* res) {
88169
((uint64_t*) res)[1] = mul128(((uint64_t*) a)[0], ((uint64_t*) b)[0], (uint64_t*) res);
89170
}
90171

91-
static void mul_sum_xor_dst(const uint8_t* a, uint8_t* c, uint8_t* dst) {
92-
uint64_t hi, lo = mul128(((uint64_t*) a)[0], ((uint64_t*) dst)[0], &hi) + ((uint64_t*) c)[1];
93-
hi += ((uint64_t*) c)[0];
94-
95-
((uint64_t*) c)[0] = ((uint64_t*) dst)[0] ^ hi;
96-
((uint64_t*) c)[1] = ((uint64_t*) dst)[1] ^ lo;
97-
((uint64_t*) dst)[0] = hi;
98-
((uint64_t*) dst)[1] = lo;
99-
}
100-
101172
static void sum_half_blocks(uint8_t* a, const uint8_t* b) {
102173
uint64_t a0, a1, b0, b1;
103174

@@ -141,7 +212,7 @@ struct cryptonight_ctx {
141212
union cn_slow_hash_state state;
142213
uint8_t text[INIT_SIZE_BYTE];
143214
uint8_t a[AES_BLOCK_SIZE];
144-
uint8_t b[AES_BLOCK_SIZE];
215+
uint8_t b[AES_BLOCK_SIZE * 2];
145216
uint8_t c[AES_BLOCK_SIZE];
146217
uint8_t aes_key[AES_KEY_SIZE];
147218
oaes_ctx* aes_ctx;
@@ -156,6 +227,7 @@ void cryptonight_hash(const char* input, char* output, uint32_t len, int variant
156227
size_t i, j;
157228

158229
VARIANT1_INIT();
230+
VARIANT2_INIT(ctx->b, ctx->state);
159231

160232
oaes_key_import_data(ctx->aes_ctx, ctx->aes_key, AES_KEY_SIZE);
161233
for (i = 0; i < MEMORY / INIT_SIZE_BYTE; i++) {
@@ -180,14 +252,37 @@ void cryptonight_hash(const char* input, char* output, uint32_t len, int variant
180252
/* Iteration 1 */
181253
j = e2i(ctx->a);
182254
aesb_single_round(&ctx->long_state[j * AES_BLOCK_SIZE], ctx->c, ctx->a);
255+
VARIANT2_SHUFFLE_ADD(ctx->long_state, j * AES_BLOCK_SIZE, ctx->a, ctx->b);
183256
xor_blocks_dst(ctx->c, ctx->b, &ctx->long_state[j * AES_BLOCK_SIZE]);
184-
VARIANT1_1((uint8_t*)&ctx->long_state[j * AES_BLOCK_SIZE]);
257+
VARIANT1_1((uint8_t*)&ctx->long_state[j * AES_BLOCK_SIZE]);
185258
/* Iteration 2 */
186-
mul_sum_xor_dst(ctx->c, ctx->a,
187-
&ctx->long_state[e2i(ctx->c) * AES_BLOCK_SIZE]);
259+
j = e2i(ctx->c);
260+
261+
uint64_t* dst = (uint64_t*)&ctx->long_state[j * AES_BLOCK_SIZE];
262+
263+
uint64_t t[2];
264+
t[0] = dst[0];
265+
t[1] = dst[1];
266+
267+
VARIANT2_INTEGER_MATH(t, ctx->c);
268+
269+
uint64_t hi;
270+
uint64_t lo = mul128(((uint64_t*)ctx->c)[0], t[0], &hi);
271+
272+
VARIANT2_SHUFFLE_ADD(ctx->long_state, j * AES_BLOCK_SIZE, ctx->a, ctx->b);
273+
274+
((uint64_t*)ctx->a)[0] += hi;
275+
((uint64_t*)ctx->a)[1] += lo;
276+
277+
dst[0] = ((uint64_t*)ctx->a)[0];
278+
dst[1] = ((uint64_t*)ctx->a)[1];
279+
280+
((uint64_t*)ctx->a)[0] ^= t[0];
281+
((uint64_t*)ctx->a)[1] ^= t[1];
282+
283+
VARIANT1_2((uint8_t*)&ctx->long_state[j * AES_BLOCK_SIZE]);
284+
copy_block(ctx->b + AES_BLOCK_SIZE, ctx->b);
188285
copy_block(ctx->b, ctx->c);
189-
VARIANT1_2((uint8_t*)
190-
&ctx->long_state[e2i(ctx->c) * AES_BLOCK_SIZE]);
191286
}
192287

193288
memcpy(ctx->text, ctx->state.init, INIT_SIZE_BYTE);

0 commit comments

Comments
 (0)