1+ #pragma once
2+
3+ #include < snmalloc/aal/aal.h>
4+ #include < snmalloc/stl/atomic.h>
5+ #include < stddef.h>
6+
7+ #ifdef SNMALLOC_PTHREAD_ATFORK_WORKS
8+ # include < pthread.h>
9+ #endif
10+
11+ namespace snmalloc
12+ {
13+ // This is a simple implementation of a class that can be
14+ // used to prevent a process from forking. Holding a lock
15+ // in the allocator while forking can lead to deadlocks.
16+ // This causes the fork to wait out any other threads inside
17+ // the allocators locks.
18+ //
19+ // The use is
20+ // ```
21+ // {
22+ // PreventFork pf;
23+ // // Code that should not be running during a fork.
24+ // }
25+ // ```
26+ class PreventFork
27+ {
28+ // Global atomic counter of the number of threads currently preventing the
29+ // system from forking. The bottom bit is used to signal that a thread is
30+ // wanting to fork.
31+ static inline stl::Atomic<size_t > threads_preventing_fork{0 };
32+
33+ // The depth of the current thread's prevention of forking.
34+ // This is used to enable reentrant prevention of forking.
35+ static inline thread_local size_t depth_of_prevention{0 };
36+
37+ // There could be multiple copies of the atfork handler installed.
38+ // Only perform work for the first prefork and final postfork.
39+ static inline thread_local size_t depth_of_handlers{0 };
40+
41+ // This function ensures that the fork handler has been installed at least
42+ // once. It might be installed more than once, this is safe. As subsequent
43+ // calls would be ignored.
44+ static void ensure_init ()
45+ {
46+ #ifdef SNMALLOC_PTHREAD_ATFORK_WORKS
47+ static stl::Atomic<bool > initialised{false };
48+
49+ if (initialised.load (stl::memory_order_acquire))
50+ return ;
51+
52+ pthread_atfork (prefork, postfork_parent, postfork_child);
53+ initialised.store (true , stl::memory_order_release);
54+ #endif
55+ };
56+
57+ public:
58+ PreventFork ()
59+ {
60+ if (depth_of_prevention++ == 0 )
61+ {
62+ // Ensure that the system is initialised before we start.
63+ // Don't do this on nested Prevent calls.
64+ ensure_init ();
65+ while (true )
66+ {
67+ auto prev = threads_preventing_fork.fetch_add (2 );
68+ if (prev % 2 == 0 )
69+ break ;
70+
71+ threads_preventing_fork.fetch_sub (2 );
72+
73+ while ((threads_preventing_fork.load () % 2 ) == 1 )
74+ {
75+ Aal::pause ();
76+ }
77+ };
78+ }
79+ }
80+
81+ ~PreventFork ()
82+ {
83+ if (--depth_of_prevention == 0 )
84+ {
85+ threads_preventing_fork -= 2 ;
86+ }
87+ }
88+
89+ // The function that notifies new threads not to enter PreventFork regions
90+ // It waits until all threads are no longer in a PreventFork region before
91+ // returning.
92+ static void prefork ()
93+ {
94+ if (depth_of_handlers++ != 0 )
95+ return ;
96+
97+ if (depth_of_prevention != 0 )
98+ error (" Fork attempted while in PreventFork region." );
99+
100+ while (true )
101+ {
102+ auto current = threads_preventing_fork.load ();
103+ if (
104+ (current % 2 == 0 ) &&
105+ (threads_preventing_fork.compare_exchange_weak (current, current + 1 )))
106+ {
107+ break ;
108+ }
109+ Aal::pause ();
110+ };
111+
112+ while (threads_preventing_fork.load () != 1 )
113+ {
114+ Aal::pause ();
115+ }
116+
117+ // Finally set the flag that allows this thread to enter PreventFork
118+ // regions This is safe as the only other calls here are to other prefork
119+ // handlers.
120+ depth_of_prevention++;
121+ }
122+
123+ // Unsets the flag that allows threads to enter PreventFork regions
124+ // and for another thread to request a fork.
125+ static void postfork_child ()
126+ {
127+ // Count out the number of handlers that have been called, and
128+ // only perform on the last.
129+ if (--depth_of_handlers != 0 )
130+ return ;
131+
132+ // This thread is no longer preventing a fork, so decrement the counter.
133+ depth_of_prevention--;
134+
135+ // Allow other threads to allocate
136+ // There could have been threads spinning in the prefork handler having
137+ // optimistically increasing thread_preventing_fork by 2, but now the
138+ // threads do not exist due to the fork. So restart the counter in the
139+ // child.
140+ threads_preventing_fork = 0 ;
141+ }
142+
143+ // Unsets the flag that allows threads to enter PreventFork regions
144+ // and for another thread to request a fork.
145+ static void postfork_parent ()
146+ {
147+ // Count out the number of handlers that have been called, and
148+ // only perform on the last.
149+ if (--depth_of_handlers != 0 )
150+ return ;
151+
152+ // This thread is no longer preventing a fork, so decrement the counter.
153+ depth_of_prevention--;
154+
155+ // Allow other threads to allocate
156+ // Just remove the bit, and let the potential other threads in prefork
157+ // remove their counts.
158+ threads_preventing_fork--;
159+ }
160+ };
161+ } // namespace snmalloc
0 commit comments