Skip to content

Commit f1b2c9d

Browse files
authored
Fix performance downgrade issue & update doc (#229)
For push function, we only need to make sure the instruction `st.global` will be executed after the while loop. Since there is a Write-After-Read hazard for `trigger.fst` (Check `this->triggers[curFifoHead % size].fst != 0` first then write value to `triggers[curFifoHead % size]`), we can expect the compiler and hardware can handle this situation correctly. Remove the `release.sys` there. BTW, `st.global.release.sys.v2.u64` will cause perf regression issue. Previous we use `st.global.release.cta.v2.u64`, but seems not necessary.
1 parent 351b95b commit f1b2c9d

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

docs/quickstart.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ $ make -j allgather_test_perf allreduce_test_perf
106106
For example, the following command runs the `allreduce5` algorithm with 8 GPUs starting from 3MB to 48MB messages, by doubling the message size in between. You can try different algorithms by changing the `-k 5` option to another value (e.g., `-k 3` runs `allreduce3`). Check all algorithms from the code: [allreduce_test.cu](https://github.com/microsoft/mscclpp/blob/main/test/mscclpp-test/allreduce_test.cu) and [allgather_test.cu](https://github.com/microsoft/mscclpp/blob/main/test/mscclpp-test/allgather_test.cu).
107107

108108
```bash
109-
$ mpirun --bind-to-numa -np 8 ./test/mscclpp-test/allreduce_test_perf -b 3m -e 48m -G 100 -n 100 -w 20 -f 2 -k 5
109+
$ mpirun --bind-to numa -np 8 ./test/mscclpp-test/allreduce_test_perf -b 3m -e 48m -G 100 -n 100 -w 20 -f 2 -k 5
110110
```
111111

112112
*NOTE: a few algorithms set a condition on the total data size, such as to be a multiple of 3. If the condition is unmet, the command will throw a regarding error.*

include/mscclpp/fifo_device.hpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ struct alignas(16) ProxyTrigger {
2525
uint64_t fst, snd;
2626
};
2727

28-
/// A concurrent FIFO where multiple device threads can push work elements and a single host proxy thread consumes them.
28+
/// A concurrent FIFO where multiple device threads (the number of threads should not exceed the fifo size) can push
29+
/// work elements and a single host proxy thread consumes them.
2930
///
3031
/// The FIFO has a head pointer allocated on the device which starts at 0 and goes up to 2^64-1, which is almost
3132
/// infinity. There are two copies of the tail, one on the device, @ref FifoDeviceHandle::tailReplica, and another on
@@ -64,9 +65,10 @@ struct FifoDeviceHandle {
6465

6566
ProxyTrigger* triggerPtr = &(this->triggers[curFifoHead % size]);
6667

67-
// store with memory order release so that the while loop does not go pass this.
68+
// There is a Write-After-Read hazard for the triggerPtr->fst. So the st instruction will not be executed
69+
// before the loop.
6870
#if defined(MSCCLPP_DEVICE_CUDA)
69-
asm volatile("st.global.release.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
71+
asm volatile("st.global.relaxed.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
7072
#else // !defined(MSCCLPP_DEVICE_CUDA)
7173
// TODO: both atomic and clang built-ins are buggy here
7274
triggerPtr->fst = trigger.fst;

0 commit comments

Comments
 (0)