Skip to content

Commit 998debe

Browse files
authored
Merge pull request #1183 from ldorau/Fix_segfault_in_cu_memory_provider_get_last_native_error
Fix segfault in `cu_memory_provider_get_last_native_error()`
2 parents e3cb666 + 7173cc5 commit 998debe

File tree

1 file changed

+29
-10
lines changed

1 file changed

+29
-10
lines changed

src/provider/provider_cuda.c

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -542,22 +542,41 @@ static void cu_memory_provider_get_last_native_error(void *provider,
542542
return;
543543
}
544544

545-
const char *error_name = 0;
546-
const char *error_string = 0;
547-
g_cu_ops.cuGetErrorName(TLS_last_native_error.native_error, &error_name);
548-
g_cu_ops.cuGetErrorString(TLS_last_native_error.native_error,
549-
&error_string);
550-
545+
CUresult result;
551546
size_t buf_size = 0;
552-
strncpy(TLS_last_native_error.msg_buff, error_name, TLS_MSG_BUF_LEN - 1);
553-
buf_size = strlen(TLS_last_native_error.msg_buff);
547+
const char *error_name = NULL;
548+
const char *error_string = NULL;
549+
550+
// If the error code is not recognized,
551+
// CUDA_ERROR_INVALID_VALUE will be returned
552+
// and error_name will be set to the NULL address.
553+
result = g_cu_ops.cuGetErrorName(TLS_last_native_error.native_error,
554+
&error_name);
555+
if (result == CUDA_SUCCESS && error_name != NULL) {
556+
strncpy(TLS_last_native_error.msg_buff, error_name,
557+
TLS_MSG_BUF_LEN - 1);
558+
} else {
559+
strncpy(TLS_last_native_error.msg_buff, "cuGetErrorName() failed",
560+
TLS_MSG_BUF_LEN - 1);
561+
}
554562

563+
buf_size = strlen(TLS_last_native_error.msg_buff);
555564
strncat(TLS_last_native_error.msg_buff, " - ",
556565
TLS_MSG_BUF_LEN - buf_size - 1);
557566
buf_size = strlen(TLS_last_native_error.msg_buff);
558567

559-
strncat(TLS_last_native_error.msg_buff, error_string,
560-
TLS_MSG_BUF_LEN - buf_size - 1);
568+
// If the error code is not recognized,
569+
// CUDA_ERROR_INVALID_VALUE will be returned
570+
// and error_string will be set to the NULL address.
571+
result = g_cu_ops.cuGetErrorString(TLS_last_native_error.native_error,
572+
&error_string);
573+
if (result == CUDA_SUCCESS && error_string != NULL) {
574+
strncat(TLS_last_native_error.msg_buff, error_string,
575+
TLS_MSG_BUF_LEN - buf_size - 1);
576+
} else {
577+
strncat(TLS_last_native_error.msg_buff, "cuGetErrorString() failed",
578+
TLS_MSG_BUF_LEN - buf_size - 1);
579+
}
561580

562581
*pError = TLS_last_native_error.native_error;
563582
*ppMessage = TLS_last_native_error.msg_buff;

0 commit comments

Comments
 (0)