Skip to content

Commit 5b293ec

Browse files
committed
Separate simd equation impl and headers
1 parent e9287b1 commit 5b293ec

6 files changed

+106
-101
lines changed

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_library(
2525
source/mandelbrot/mandelbrot_window.cpp
2626
source/graphics/color_conversions/color_conversions.cpp
2727
source/graphics/aspect_ratio/aspect_ratio.cpp
28+
source/mandelbrot/equations_simd.cpp
2829
)
2930

3031
target_include_directories(

source/mandelbrot/equations.hpp

+7-70
Original file line numberDiff line numberDiff line change
@@ -3,90 +3,27 @@
33
#include "config.hpp"
44
#include "units.hpp"
55

6-
#include <immintrin.h>
7-
86
namespace fractal {
97
// https://en.wikipedia.org/wiki/Mandelbrot_set#Formal_definition
10-
118
/// Performs one iteration of the recurrence z_{n+1} = z_n^2 + c.
///
/// @param z_n      Current value of the sequence.
/// @param constant The additive constant c.
/// @return The next value of the sequence.
inline std::complex<complex_underlying>
step(std::complex<complex_underlying> z_n, std::complex<complex_underlying> constant)
{
    const std::complex<complex_underlying> squared = z_n * z_n;
    return squared + constant;
}
1613

17-
inline std::array<iteration_count, 8> compute_iterations(
18-
std::array<std::complex<complex_underlying>, 8> z_0,
19-
std::array<std::complex<complex_underlying>, 8> constant, iteration_count max_iters
14+
inline iteration_count compute_iterations(
15+
std::complex<complex_underlying> z_0, std::complex<complex_underlying> constant,
16+
iteration_count max_iters
2017
)
2118
{
22-
static const auto SQUARED_DIVERGENCE =
23-
MANDELBROT_DIVERGENCE_NORM * MANDELBROT_DIVERGENCE_NORM;
24-
25-
alignas(64) std::array<double, 8> reals = {z_0[0].real(), z_0[1].real(),
26-
z_0[2].real(), z_0[3].real(),
27-
z_0[4].real(), z_0[5].real(),
28-
z_0[6].real(), z_0[7].real()};
29-
alignas(64) std::array<double, 8> imags = {z_0[0].imag(), z_0[1].imag(),
30-
z_0[2].imag(), z_0[3].imag(),
31-
z_0[4].imag(), z_0[5].imag(),
32-
z_0[6].imag(), z_0[7].imag()};
33-
alignas(64) std::array<double, 8> const_reals = {
34-
constant[0].real(), constant[1].real(), constant[2].real(), constant[3].real(),
35-
constant[4].real(), constant[5].real(), constant[6].real(), constant[7].real()
36-
};
37-
alignas(64) std::array<double, 8> const_imags = {
38-
constant[0].imag(), constant[1].imag(), constant[2].imag(), constant[3].imag(),
39-
constant[4].imag(), constant[5].imag(), constant[6].imag(), constant[7].imag()
40-
};
41-
42-
std::array<iteration_count, 8> solved_its = {0};
4319
iteration_count iterations = 0;
20+
std::complex<complex_underlying> z_n = z_0;
4421

45-
__m512d input_vec_real = _mm512_load_pd(reals.data());
46-
__m512d input_vec_imag = _mm512_load_pd(imags.data());
47-
__m512d input_vec_constant_imags = _mm512_load_pd(const_imags.data());
48-
__m512d input_vec_constant_reals = _mm512_load_pd(const_reals.data());
49-
__m512i solved_its_vec = _mm512_loadu_epi16(solved_its.data());
50-
51-
while (iterations < max_iters) {
52-
// Square real
53-
__m512d squared_vec_real = _mm512_mul_pd(input_vec_real, input_vec_real);
54-
55-
// Square imag
56-
__m512d squared_vec_imag = _mm512_mul_pd(input_vec_imag, input_vec_imag);
57-
58-
// Create imags
59-
__m512d real_x2 = _mm512_mul_pd(input_vec_real, _mm512_set1_pd(2));
60-
input_vec_imag =
61-
_mm512_fmadd_pd(real_x2, input_vec_imag, input_vec_constant_imags);
62-
63-
// Create reals
64-
__m512d subtracted_squared = _mm512_sub_pd(squared_vec_real, squared_vec_imag);
65-
input_vec_real = _mm512_add_pd(subtracted_squared, input_vec_constant_reals);
66-
67-
// Create squared norms
68-
__m512d squared_norms_vec = _mm512_add_pd(squared_vec_real, squared_vec_imag);
69-
__mmask8 solved_mask = _mm512_cmp_pd_mask(
70-
squared_norms_vec, _mm512_set1_pd(SQUARED_DIVERGENCE), _CMP_GT_OS
71-
);
72-
73-
uint32_t solved = _cvtmask8_u32(solved_mask);
74-
solved_its_vec = _mm512_mask_blend_epi16(
75-
solved_mask, solved_its_vec, _mm512_set1_epi16(iterations)
76-
);
77-
if (solved == 0xFF) [[unlikely]]
78-
break;
79-
22+
while (iterations < max_iters && std::norm(z_n) < MANDELBROT_DIVERGENCE_NORM) {
23+
z_n = step(z_n, constant);
8024
++iterations;
8125
}
8226

83-
_mm512_storeu_epi16(solved_its.data(), solved_its_vec);
84-
for (int i = 0; i < 8; i++) {
85-
if (solved_its[i] == 0) {
86-
solved_its[i] = max_iters;
87-
}
88-
}
89-
90-
return solved_its;
27+
return iterations;
9128
}
9229
} // namespace fractal

source/mandelbrot/equations_nsd.hpp

-29
This file was deleted.

source/mandelbrot/equations_simd.cpp

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include "equations.hpp"
2+
3+
namespace fractal {
4+
std::array<iteration_count, 8> compute_iterations(
5+
const std::array<std::complex<complex_underlying>, 8>& z_0,
6+
const std::array<std::complex<complex_underlying>, 8>& constant,
7+
iteration_count max_iters
8+
)
9+
{
10+
static const auto SQUARED_DIVERGENCE =
11+
MANDELBROT_DIVERGENCE_NORM * MANDELBROT_DIVERGENCE_NORM;
12+
13+
alignas(64) std::array<double, 8> reals = {z_0[0].real(), z_0[1].real(),
14+
z_0[2].real(), z_0[3].real(),
15+
z_0[4].real(), z_0[5].real(),
16+
z_0[6].real(), z_0[7].real()};
17+
alignas(64) std::array<double, 8> imags = {z_0[0].imag(), z_0[1].imag(),
18+
z_0[2].imag(), z_0[3].imag(),
19+
z_0[4].imag(), z_0[5].imag(),
20+
z_0[6].imag(), z_0[7].imag()};
21+
alignas(64) std::array<double, 8> const_reals = {
22+
constant[0].real(), constant[1].real(), constant[2].real(), constant[3].real(),
23+
constant[4].real(), constant[5].real(), constant[6].real(), constant[7].real()
24+
};
25+
alignas(64) std::array<double, 8> const_imags = {
26+
constant[0].imag(), constant[1].imag(), constant[2].imag(), constant[3].imag(),
27+
constant[4].imag(), constant[5].imag(), constant[6].imag(), constant[7].imag()
28+
};
29+
30+
std::array<iteration_count, 8> solved_its = {0};
31+
32+
__m512d input_vec_real = _mm512_load_pd(reals.data());
33+
__m512d input_vec_imag = _mm512_load_pd(imags.data());
34+
__m512d input_vec_constant_imags = _mm512_load_pd(const_imags.data());
35+
__m512d input_vec_constant_reals = _mm512_load_pd(const_reals.data());
36+
__m512i solved_its_vec = _mm512_loadu_epi16(solved_its.data());
37+
38+
for (iteration_count iterations = 0; iterations < max_iters; iterations++) {
39+
// Square real
40+
__m512d squared_vec_real = _mm512_mul_pd(input_vec_real, input_vec_real);
41+
42+
// Square imag
43+
__m512d squared_vec_imag = _mm512_mul_pd(input_vec_imag, input_vec_imag);
44+
45+
// Create imags
46+
__m512d real_x2 = _mm512_mul_pd(input_vec_real, _mm512_set1_pd(2));
47+
input_vec_imag =
48+
_mm512_fmadd_pd(real_x2, input_vec_imag, input_vec_constant_imags);
49+
50+
// Create reals
51+
__m512d subtracted_squared = _mm512_sub_pd(squared_vec_real, squared_vec_imag);
52+
input_vec_real = _mm512_add_pd(subtracted_squared, input_vec_constant_reals);
53+
54+
// Create squared norms
55+
__m512d squared_norms_vec = _mm512_add_pd(squared_vec_real, squared_vec_imag);
56+
__mmask8 solved_mask = _mm512_cmp_pd_mask(
57+
squared_norms_vec, _mm512_set1_pd(SQUARED_DIVERGENCE), _CMP_GT_OS
58+
);
59+
60+
uint32_t solved = _cvtmask8_u32(solved_mask);
61+
solved_its_vec = _mm512_mask_blend_epi16(
62+
solved_mask, solved_its_vec,
63+
_mm512_set1_epi16(static_cast<int16_t>(iterations))
64+
);
65+
if (solved == 0xFF) [[unlikely]]
66+
break;
67+
}
68+
69+
__mmask32 mask = _mm512_cmpeq_epi16_mask(solved_its_vec, _mm512_set1_epi16(0));
70+
solved_its_vec = _mm512_mask_mov_epi16(
71+
solved_its_vec, mask, _mm512_set1_epi16(static_cast<int16_t>(max_iters))
72+
);
73+
_mm512_storeu_epi16(solved_its.data(), solved_its_vec);
74+
75+
return solved_its;
76+
}
77+
} // namespace fractal

source/mandelbrot/equations_simd.hpp

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#pragma once
2+
3+
#include "config.hpp"
4+
#include "units.hpp"
5+
6+
#include <immintrin.h>
7+
8+
namespace fractal {
9+
// https://en.wikipedia.org/wiki/Mandelbrot_set#Formal_definition
10+
11+
// Declaration of a single recurrence step taking the current value z_n and
// the additive constant, both by value, and returning the next value.
// NOTE(review): an inline definition with this signature appears in
// equations.hpp, but no out-of-line definition is visible in
// equations_simd.cpp - confirm a definition is reachable before calling
// this through this header alone.
std::complex<complex_underlying>
12+
step(std::complex<complex_underlying> z_n, std::complex<complex_underlying> constant);
13+
14+
// Eight-wide iteration counter: takes eight starting values and eight
// constants by const reference plus an iteration limit, and returns one
// iteration count per input position.
// The matching definition shown alongside this header is in
// equations_simd.cpp and is written with AVX-512 intrinsics.
std::array<iteration_count, 8> compute_iterations(
15+
const std::array<std::complex<complex_underlying>, 8>& z_0,
16+
const std::array<std::complex<complex_underlying>, 8>& constant,
17+
iteration_count max_iters
18+
);
19+
} // namespace fractal

source/mandelbrot/mandelbrot_window.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include "config.hpp"
44
#include "coordinates.hpp"
5-
#include "equations.hpp"
5+
#include "equations_simd.hpp"
66
#include "graphics/aspect_ratio/aspect_ratio.hpp"
77
#include "graphics/color_conversions/color_conversions.hpp"
88
#include "graphics/display_to_complex.hpp"
@@ -28,7 +28,7 @@ void MandelbrotWindow::draw_coordinate_(
2828
const std::array<std::complex<complex_underlying>, 8>& complex_coords
2929
)
3030
{
31-
std::array<std::complex<complex_underlying>, 8> starts = {
31+
static constexpr std::array<std::complex<complex_underlying>, 8> starts = {
3232
std::complex<complex_underlying>{0, 0}
3333
};
3434
auto iterations =

0 commit comments

Comments
 (0)