Skip to content

Commit 444039c

Browse files
izdebyfacebook-github-bot
authored andcommitted
Bool tensor. Part 0: Boolean storage implementation (pytorch#16810)
Summary: This is the first commit from a series of planned changes in order to add boolean tensors to PyTorch. The whole plan looks like this: 0. Storage Implementation (this change) 1. Tensor Creation. 2. Tensor Conversions. 3. Tensor Indexing. 4. Tensor Operations. 5. Back compatibility related changes. This feature was requested by the community: pytorch#4764 pytorch#4219 pytorch#4288 **Change**: Added boolean type to the Storage class for CPU and CUDA backends. **Tested via**: 1. unit tests 2. running this: -> import torch -> torch.BoolStorage <class 'torch.BoolStorage'> -> torch.cuda.BoolStorage <class 'torch.cuda.BoolStorage'> Pull Request resolved: pytorch#16810 Reviewed By: gchanan Differential Revision: D14087246 Pulled By: izdeby fbshipit-source-id: 042642ced1cb0fd1bb6bff05f9ca871a5c54ee5e
1 parent e81878e commit 444039c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+323
-20
lines changed

aten/src/ATen/DLConvertor.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ static DLDataType getDLDataType(const Type& type) {
3737
case ScalarType::Half:
3838
dtype.code = DLDataTypeCode::kDLFloat;
3939
break;
40+
case ScalarType::Bool:
41+
dtype.code = DLDataTypeCode::kDLUInt;
42+
break;
4043
case ScalarType::ComplexHalf:
4144
throw std::logic_error("ComplexHalf is not supported by dlpack");
4245
case ScalarType::ComplexFloat:

aten/src/ATen/core/Type.h

+4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ struct Generator;
4444
static inline void noop_deleter(void*) {}
4545

4646
enum class TypeID {
47+
CPUBool,
4748
CPUByte,
4849
CPUChar,
4950
CPUDouble,
@@ -52,13 +53,15 @@ enum class TypeID {
5253
CPULong,
5354
CPUShort,
5455
CPUHalf,
56+
SparseCPUBool,
5557
SparseCPUByte,
5658
SparseCPUChar,
5759
SparseCPUDouble,
5860
SparseCPUFloat,
5961
SparseCPUInt,
6062
SparseCPULong,
6163
SparseCPUShort,
64+
CUDABool,
6265
CUDAByte,
6366
CUDAChar,
6467
CUDADouble,
@@ -67,6 +70,7 @@ enum class TypeID {
6770
CUDALong,
6871
CUDAShort,
6972
CUDAHalf,
73+
SparseCUDABool,
7074
SparseCUDAByte,
7175
SparseCUDAChar,
7276
SparseCUDADouble,

aten/src/ATen/gen.py

+1
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def check_all_files_written(self):
178178

179179
# scalar_name, c_type, accreal, th_scalar_type, is_floating_type
180180
scalar_types = [
181+
('Bool', 'uint8_t', 'BoolAccrealNotDefined', 'uint8_t', False),
181182
('Byte', 'uint8_t', 'Long', 'uint8_t', False),
182183
('Char', 'int8_t', 'Long', 'int8_t', False),
183184
('Double', 'double', 'Double', 'double', True),

aten/src/TH/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ INSTALL(FILES
6464
THFilePrivate.h
6565
${CMAKE_CURRENT_BINARY_DIR}/THGeneral.h
6666
THGenerateAllTypes.h
67+
THGenerateBoolType.h
6768
THGenerateDoubleType.h
6869
THGenerateFloatType.h
6970
THGenerateHalfType.h

aten/src/TH/THFile.h

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ TH_API size_t THFile_readInt(THFile *self, THIntStorage *storage);
4646
TH_API size_t THFile_readLong(THFile *self, THLongStorage *storage);
4747
TH_API size_t THFile_readFloat(THFile *self, THFloatStorage *storage);
4848
TH_API size_t THFile_readDouble(THFile *self, THDoubleStorage *storage);
49+
TH_API size_t THFile_readBool(THFile *self, THBoolStorage *storage);
4950

5051
TH_API size_t THFile_writeByte(THFile *self, THByteStorage *storage);
5152
TH_API size_t THFile_writeChar(THFile *self, THCharStorage *storage);
@@ -54,6 +55,7 @@ TH_API size_t THFile_writeInt(THFile *self, THIntStorage *storage);
5455
TH_API size_t THFile_writeLong(THFile *self, THLongStorage *storage);
5556
TH_API size_t THFile_writeFloat(THFile *self, THFloatStorage *storage);
5657
TH_API size_t THFile_writeDouble(THFile *self, THDoubleStorage *storage);
58+
TH_API size_t THFile_writeBool(THFile *self, THBoolStorage *storage);
5759

5860
/* raw */
5961
TH_API size_t THFile_readByteRaw(THFile *self, uint8_t *data, size_t n);

aten/src/TH/THGenerateBoolType.h

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef TH_GENERIC_FILE
2+
#error "You must define TH_GENERIC_FILE before including THGenerateBoolType.h"
3+
#endif
4+
5+
// TODO: define accreal type once the correct value is known.
6+
#define scalar_t bool
7+
#define ureal bool
8+
#define Real Bool
9+
#define TH_CONVERT_ACCREAL_TO_REAL(_val) (scalar_t)(_val)
10+
#define TH_REAL_IS_BOOL
11+
#line 1 TH_GENERIC_FILE
12+
#include TH_GENERIC_FILE
13+
#undef scalar_t
14+
#undef ureal
15+
#undef Real
16+
#undef TH_REAL_IS_BOOL
17+
#undef TH_CONVERT_REAL_TO_ACCREAL
18+
#undef TH_CONVERT_ACCREAL_TO_REAL
19+
20+
#ifndef THGenerateManyTypes
21+
#undef TH_GENERIC_FILE
22+
#endif

aten/src/TH/THStorageFunctions.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,18 @@
99
#include <TH/generic/THStorage.cpp>
1010
#include <TH/THGenerateHalfType.h>
1111

12+
#include <TH/generic/THStorage.cpp>
13+
#include <TH/THGenerateBoolType.h>
14+
1215
#include <TH/generic/THStorageCopy.cpp>
1316
#include <TH/THGenerateAllTypes.h>
1417

1518
#include <TH/generic/THStorageCopy.cpp>
1619
#include <TH/THGenerateHalfType.h>
1720

21+
#include <TH/generic/THStorageCopy.cpp>
22+
#include <TH/THGenerateBoolType.h>
23+
1824
THStorage* THStorage_new(caffe2::TypeMeta data_type) {
1925
THStorage* storage = c10::make_intrusive<at::StorageImpl>(
2026
data_type,

aten/src/TH/THStorageFunctions.h

+6
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,17 @@
1111
#include <TH/generic/THStorage.h>
1212
#include <TH/THGenerateHalfType.h>
1313

14+
#include <TH/generic/THStorage.h>
15+
#include <TH/THGenerateBoolType.h>
16+
1417
#include <TH/generic/THStorageCopy.h>
1518
#include <TH/THGenerateAllTypes.h>
1619

1720
#include <TH/generic/THStorageCopy.h>
1821
#include <TH/THGenerateHalfType.h>
1922

23+
#include <TH/generic/THStorageCopy.h>
24+
#include <TH/THGenerateBoolType.h>
25+
2026
// This exists to have a data-type independent way of freeing (necessary for THPPointer).
2127
TH_API void THStorage_free(THStorage *storage);

aten/src/TH/THTensor.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#include <TH/generic/THTensor.cpp>
77
#include <TH/THGenerateHalfType.h>
88

9+
#include <TH/generic/THTensor.cpp>
10+
#include <TH/THGenerateBoolType.h>
11+
912
#include <ATen/native/Resize.h>
1013

1114
#include <numeric>

aten/src/TH/THTensor.h

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#include <TH/generic/THTensor.h>
1414
#include <TH/THGenerateHalfType.h>
1515

16+
#include <TH/generic/THTensor.h>
17+
#include <TH/THGenerateBoolType.h>
18+
1619
/* random numbers */
1720
#include <TH/THRandom.h>
1821
#include <TH/generic/THTensorRandom.h>

aten/src/TH/THTensor.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ TH_CPP_API c10::optional<std::vector<int64_t>> THTensor_compute_stride(
127127

128128
#include <TH/generic/THTensor.hpp>
129129
#include <TH/THGenerateHalfType.h>
130+
131+
#include <TH/generic/THTensor.hpp>
132+
#include <TH/THGenerateBoolType.h>

aten/src/TH/generic/THStorage.h

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#define THShortStorage THStorage
3434
#define THIntStorage THStorage
3535
#define THLongStorage THStorage
36+
#define THBoolStorage THStorage
3637

3738
TH_API scalar_t* THStorage_(data)(const THStorage*);
3839
TH_API ptrdiff_t THStorage_(size)(const THStorage*);

aten/src/TH/generic/THStorageCopy.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,6 @@ IMPLEMENT_THStorage_COPY(Long)
3737
IMPLEMENT_THStorage_COPY(Float)
3838
IMPLEMENT_THStorage_COPY(Double)
3939
IMPLEMENT_THStorage_COPY(Half)
40+
IMPLEMENT_THStorage_COPY(Bool)
4041

4142
#endif

aten/src/TH/generic/THStorageCopy.h

+1
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ TH_API void THStorage_(copyLong)(THStorage *storage, struct THLongStorage *src);
1414
TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src);
1515
TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src);
1616
TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src);
17+
TH_API void THStorage_(copyBool)(THStorage *storage, struct THBoolStorage *src);
1718

1819
#endif

aten/src/TH/generic/THTensor.h

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#define THShortTensor THTensor
1919
#define THIntTensor THTensor
2020
#define THLongTensor THTensor
21+
#define THBoolTensor THTensor
2122

2223
/**** access methods ****/
2324
TH_API THStorage* THTensor_(storage)(const THTensor *self);

aten/src/THC/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ INSTALL(FILES
7878
THCDeviceTensorUtils.cuh
7979
THCDeviceTensorUtils-inl.cuh
8080
THCGenerateAllTypes.h
81+
THCGenerateBoolType.h
8182
THCGenerateByteType.h
8283
THCGenerateCharType.h
8384
THCGenerateShortType.h

aten/src/THC/THCGenerateBoolType.h

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef THC_GENERIC_FILE
2+
#error "You must define THC_GENERIC_FILE before including THCGenerateBoolType.h"
3+
#endif
4+
5+
// TODO: define accreal type once the correct value is known.
6+
#define scalar_t bool
7+
#define ureal bool
8+
#define Real Bool
9+
#define CReal CudaBool
10+
#define THC_REAL_IS_BOOL
11+
#line 1 THC_GENERIC_FILE
12+
#include THC_GENERIC_FILE
13+
#undef scalar_t
14+
#undef ureal
15+
#undef Real
16+
#undef CReal
17+
#undef THC_REAL_IS_BOOL
18+
19+
#ifndef THCGenerateBoolType
20+
#undef THC_GENERIC_FILE
21+
#endif

aten/src/THC/THCStorage.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
#include <THC/generic/THCStorage.cpp>
99
#include <THC/THCGenerateAllTypes.h>
1010

11+
#include <THC/generic/THCStorage.cpp>
12+
#include <THC/THCGenerateBoolType.h>
13+
1114
#include <c10/util/intrusive_ptr.h>
1215

1316
void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size)

aten/src/THC/THCStorage.cu

+3
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@
1111

1212
#include <THC/generic/THCStorage.cu>
1313
#include <THC/THCGenerateAllTypes.h>
14+
15+
#include <THC/generic/THCStorage.cu>
16+
#include <THC/THCGenerateBoolType.h>

aten/src/THC/THCStorage.h

+3
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,7 @@
99
#include <THC/generic/THCStorage.h>
1010
#include <THC/THCGenerateAllTypes.h>
1111

12+
#include <THC/generic/THCStorage.h>
13+
#include <THC/THCGenerateBoolType.h>
14+
1215
#endif

aten/src/THC/THCStorageCopy.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@
55

66
#include <THC/generic/THCStorageCopy.cpp>
77
#include <THC/THCGenerateAllTypes.h>
8+
9+
#include <THC/generic/THCStorageCopy.cpp>
10+
#include <THC/THCGenerateBoolType.h>

aten/src/THC/THCStorageCopy.cu

+3
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@
88

99
#include <THC/generic/THCStorageCopy.cu>
1010
#include <THC/THCGenerateAllTypes.h>
11+
12+
#include <THC/generic/THCStorageCopy.cu>
13+
#include <THC/THCGenerateBoolType.h>

aten/src/THC/THCStorageCopy.h

+3
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@
88
#include <THC/generic/THCStorageCopy.h>
99
#include <THC/THCGenerateAllTypes.h>
1010

11+
#include <THC/generic/THCStorageCopy.h>
12+
#include <THC/THCGenerateBoolType.h>
13+
1114
#endif

aten/src/THC/THCTensor.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
#include <THC/generic/THCTensor.cpp>
88
#include <THC/THCGenerateAllTypes.h>
99

10+
#include <THC/generic/THCTensor.cpp>
11+
#include <THC/THCGenerateBoolType.h>
12+
1013
#include <THC/THCTensorInfo.cuh>
1114

1215
#include <ATen/native/cuda/Resize.cuh>
@@ -61,6 +64,8 @@ THCTensor *THCTensor_new(THCState *state, caffe2::TypeMeta type_meta) {
6164
return THCudaTensor_new(state);
6265
case at::ScalarType::Double:
6366
return THCudaDoubleTensor_new(state);
67+
case at::ScalarType::Bool:
68+
return THCudaBoolTensor_new(state);
6469
default:
6570
AT_ERROR("unexpected ScalarType: ", toString(scalar_type));
6671
}

aten/src/THC/THCTensor.cu

+3
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@
33

44
#include <THC/generic/THCTensor.cu>
55
#include <THC/THCGenerateAllTypes.h>
6+
7+
#include <THC/generic/THCTensor.cu>
8+
#include <THC/THCGenerateBoolType.h>

aten/src/THC/THCTensor.h

+2
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,6 @@ typedef struct THC_CLASS THCDescBuff
1717
#include <THC/generic/THCTensor.h>
1818
#include <THC/THCGenerateAllTypes.h>
1919

20+
#include <THC/generic/THCTensor.h>
21+
#include <THC/THCGenerateBoolType.h>
2022
#endif

aten/src/THC/THCTensor.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,6 @@ THC_API bool THCTensor_maybeOverlappingIndices(THCState* state, const THCTensor*
5656

5757
#include <THC/generic/THCTensor.hpp>
5858
#include <THC/THCGenerateAllTypes.h>
59+
60+
#include <THC/generic/THCTensor.hpp>
61+
#include <THC/THCGenerateBoolType.h>

aten/src/THC/THCTensorCopy.cu

+14-4
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,26 @@
55
#include <type_traits>
66

77
// Copy operator for the pointwise apply kernel
8-
template <typename TypeDst, typename TypeSrc>
8+
template <typename T>
99
struct CopyOp {
10-
__device__ __forceinline__ void operator()(TypeDst* dst, TypeSrc* src) {
10+
__device__ __forceinline__ void operator()(T* dst, T* src) {
1111
#if __CUDA_ARCH__ >= 350
12-
*dst = ScalarConvert<TypeSrc, TypeDst>::to(__ldg(src));
12+
*dst = ScalarConvert<T, T>::to(__ldg(src));
1313
#else
14-
*dst = ScalarConvert<TypeSrc, TypeDst>::to(*src);
14+
*dst = ScalarConvert<T, T>::to(*src);
1515
#endif
1616
}
1717
};
1818

19+
template <>
20+
struct CopyOp <bool> {
21+
__device__ __forceinline__ void operator()(bool* dst, bool* src) {
22+
*dst = ScalarConvert<bool, bool>::to(*src);
23+
}
24+
};
25+
1926
#include <THC/generic/THCTensorCopy.cu>
2027
#include <THC/THCGenerateAllTypes.h>
28+
29+
#include <THC/generic/THCTensorCopy.cu>
30+
#include <THC/THCGenerateBoolType.h>

aten/src/THC/THCTensorCopy.h

+3
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,7 @@
99
#include <THC/generic/THCTensorCopy.h>
1010
#include <THC/THCGenerateAllTypes.h>
1111

12+
#include <THC/generic/THCTensorCopy.h>
13+
#include <THC/THCGenerateBoolType.h>
14+
1215
#endif

aten/src/THC/generic/THCStorage.h

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define THCudaShortStorage THCStorage
1515
#define THCudaIntStorage THCStorage
1616
#define THCudaLongStorage THCStorage
17+
#define THCudaBoolStorage THCStorage
1718

1819
THC_API scalar_t* THCStorage_(data)(THCState *state, const THCStorage*);
1920
THC_API ptrdiff_t THCStorage_(size)(THCState *state, const THCStorage*);

aten/src/THC/generic/THCStorageCopy.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ TH_CUDA_STORAGE_IMPLEMENT_COPY(Long)
3333
TH_CUDA_STORAGE_IMPLEMENT_COPY(Float)
3434
TH_CUDA_STORAGE_IMPLEMENT_COPY(Half)
3535
TH_CUDA_STORAGE_IMPLEMENT_COPY(Double)
36+
TH_CUDA_STORAGE_IMPLEMENT_COPY(Bool)
3637

3738
void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *src)
3839
{
@@ -65,6 +66,7 @@ TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Long)
6566
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Float)
6667
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Half)
6768
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Double)
69+
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Bool)
6870

6971
#undef TH_CUDA_STORAGE_IMPLEMENT_COPY
7072
#undef TH_CUDA_STORAGE_IMPLEMENT_COPYTO

aten/src/THC/generic/THCStorageCopy.cu

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ THC_CUDA_STORAGE_IMPLEMENT_COPY(Long,Long)
2828
THC_CUDA_STORAGE_IMPLEMENT_COPY(Float,) // i.e. float
2929
THC_CUDA_STORAGE_IMPLEMENT_COPY(Double,Double)
3030
THC_CUDA_STORAGE_IMPLEMENT_COPY(Half,Half)
31+
THC_CUDA_STORAGE_IMPLEMENT_COPY(Bool,Bool)
3132

3233
#undef THC_CUDA_STORAGE_IMPLEMENT_COPY
3334

0 commit comments

Comments
 (0)