1
1
#pragma once
2
2
3
- #include " debug.hpp"
4
3
#include < cuda_runtime.h>
4
+ #include < nvfunctional>
5
5
#include < version>
6
6
#include < cstddef>
7
7
#include < cstdio>
8
8
#include < cstdlib>
9
9
#include < cstdarg>
10
+ #include < cstdarg>
10
11
#include < memory>
11
12
#include < new>
12
13
#include < string>
@@ -187,6 +188,8 @@ private:
187
188
188
189
public:
189
190
CudaMemPool (std::nullptr_t ) noexcept {}
191
+ CudaMemPool (CudaMemPool &&) = default ;
192
+ CudaMemPool &operator =(CudaMemPool &&) = default ;
190
193
191
194
struct Builder {
192
195
private:
@@ -259,12 +262,17 @@ private:
259
262
260
263
public:
261
264
CudaEvent (std::nullptr_t ) noexcept {}
265
+ CudaEvent (CudaEvent &&) = default ;
266
+ CudaEvent &operator =(CudaEvent &&) = default ;
262
267
263
268
struct Builder {
264
269
private:
265
270
int flags = cudaEventDefault;
266
271
267
272
public:
273
+ Builder () = default ;
274
+ explicit Builder (int flags) noexcept : flags(flags) {}
275
+
268
276
Builder &withBlockingSync (bool blockingSync = true ) noexcept {
269
277
if (blockingSync) {
270
278
flags |= cudaEventBlockingSync;
@@ -303,24 +311,28 @@ public:
303
311
CHECK_CUDA (cudaEventSynchronize (*this ));
304
312
}
305
313
306
- bool joinReady () const {
314
+ bool poll () const {
307
315
cudaError_t res = cudaEventQuery (*this );
308
316
if (res == cudaSuccess) {
309
317
return true ;
310
318
}
311
319
if (res == cudaErrorNotReady) {
312
320
return false ;
313
321
}
314
- CHECK_CUDA (res);
322
+ CHECK_CUDA (res /* cudaEventQuery */ );
315
323
return false ;
316
324
}
317
325
318
326
// Returns the time in milliseconds reported by cudaEventElapsedTime, using
// `event` as the start event and `*this` as the end event.
//
// The status of the call is passed to CHECK_CUDA. The result is
// zero-initialized so that an indeterminate value is never returned should
// the call fail and CHECK_CUDA return to the caller.
float elapsedMillis(CudaEvent const &event) const {
    float result = 0.0f;
    CHECK_CUDA(cudaEventElapsedTime(&result, event, *this));
    return result;
}
323
331
332
+ float operator -(CudaEvent const &event) const {
333
+ return elapsedMillis (event);
334
+ }
335
+
324
336
~CudaEvent () {
325
337
if (*this ) {
326
338
CHECK_CUDA (cudaEventDestroy (*this ));
@@ -335,12 +347,17 @@ private:
335
347
336
348
public:
337
349
CudaStream (std::nullptr_t ) noexcept {}
350
+ CudaStream (CudaStream &&) = default ;
351
+ CudaStream &operator =(CudaStream &&) = default ;
338
352
339
353
struct Builder {
340
354
private:
341
355
int flags = cudaStreamDefault;
342
356
343
357
public:
358
+ Builder () = default ;
359
+ explicit Builder (int flags) noexcept : flags(flags) {}
360
+
344
361
Builder &withNonBlocking (bool nonBlocking = true ) noexcept {
345
362
if (nonBlocking) {
346
363
flags |= cudaStreamNonBlocking;
@@ -357,10 +374,14 @@ public:
357
374
}
358
375
};
359
376
360
- static CudaStream nullStream () noexcept {
377
+ static CudaStream defaultStream () noexcept {
361
378
return CudaStream (nullptr );
362
379
}
363
380
381
+ static CudaStream perThreadStream () noexcept {
382
+ return CudaStream (cudaStreamPerThread);
383
+ }
384
+
364
385
void copy (void *dst, void *src, size_t size, cudaMemcpyKind kind) const {
365
386
CHECK_CUDA (cudaMemcpyAsync (dst, src, size, kind, *this ));
366
387
}
@@ -381,11 +402,17 @@ public:
381
402
copy (dst, src, size, cudaMemcpyHostToHost);
382
403
}
383
404
384
- void record (CudaEvent const &event) const {
405
+ void recordEvent (CudaEvent const &event) const {
385
406
CHECK_CUDA (cudaEventRecord (event, *this ));
386
407
}
387
408
388
- void wait (CudaEvent const &event,
409
+ CudaEvent recordEvent () const {
410
+ CudaEvent event = CudaEvent::Builder ().build ();
411
+ recordEvent (event);
412
+ return event;
413
+ }
414
+
415
+ void waitEvent (CudaEvent const &event,
389
416
unsigned int flags = cudaEventWaitDefault) const {
390
417
CHECK_CUDA (cudaStreamWaitEvent (*this , event, flags));
391
418
}
@@ -403,22 +430,23 @@ public:
403
430
auto userData = std::make_unique<Func>();
404
431
cudaStreamCallback_t callback = [](cudaStream_t stream,
405
432
cudaError_t status, void *userData) {
433
+ CHECK_CUDA (status /* joinAsync cudaStreamCallback */ );
406
434
std::unique_ptr<Func> func (static_cast <Func *>(userData));
407
- (*func)(stream, status );
435
+ (*func)();
408
436
};
409
437
joinAsync (callback, userData.get ());
410
438
userData.release ();
411
439
}
412
440
413
- bool joinReady () const {
441
+ bool poll () const {
414
442
cudaError_t res = cudaStreamQuery (*this );
415
443
if (res == cudaSuccess) {
416
444
return true ;
417
445
}
418
446
if (res == cudaErrorNotReady) {
419
447
return false ;
420
448
}
421
- CHECK_CUDA (res);
449
+ CHECK_CUDA (res /* cudaStreamQuery */ );
422
450
return false ;
423
451
}
424
452
@@ -428,7 +456,7 @@ public:
428
456
}
429
457
430
458
// Destroys the wrapped stream with cudaStreamDestroy, unless the wrapper is
// empty or holds cudaStreamPerThread (the handle handed out by
// perThreadStream(), which was not created by this class).
~CudaStream() {
    bool const ownsStream = *this && *this != cudaStreamPerThread;
    if (!ownsStream) {
        return;
    }
    CHECK_CUDA(cudaStreamDestroy(*this));
}
@@ -522,8 +550,8 @@ struct CudaAllocator : private Arena {
522
550
};
523
551
};
524
552
525
- template <class T >
526
- using CudaVector = std::vector<T, CudaAllocator<T>>;
553
+ template <class T , class Arena = CudaManagedArena >
554
+ using CudaVector = std::vector<T, CudaAllocator<T, Arena >>;
527
555
528
556
#if defined(__clang__) && defined(__CUDACC__) && defined(__GLIBCXX__)
529
557
__host__ __device__ static void printf (const char *fmt, ...) {
0 commit comments