diff --git a/CMakeLists.txt b/CMakeLists.txt
index 501a2de3..999c8fea 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -21,6 +21,10 @@ FILE(APPEND THCUNN_generic_h.lua "]]")
 
 FILE(GLOB luasrc *.lua)
 
+IF (ANDROID)
+  ADD_DEFINITIONS(-DTHC_MIN_MATH)
+ENDIF()
+
 ADD_SUBDIRECTORY(lib)
 
 INSTALL(
diff --git a/THCUNN.lua b/THCUNN.lua
index 6776a238..9dd90e5d 100644
--- a/THCUNN.lua
+++ b/THCUNN.lua
@@ -129,8 +129,10 @@ local function_names_generic = extract_function_names_generic(THCUNN_generic_h)
 THNN.kernels['torch.CudaTensor'] = THNN.bind(THCUNN.C, function_names_generic, 'Cuda', THCUNN.getState)
 torch.getmetatable('torch.CudaTensor').THNN = THNN.kernels['torch.CudaTensor']
 
-THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_generic, 'CudaDouble', THCUNN.getState)
-torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor']
+if not cutorch.minMath then
+  THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_generic, 'CudaDouble', THCUNN.getState)
+  torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor']
+end
 
 if cutorch.hasHalf then
    -- in order to call 'half' functions from lua, convert real arguments from
@@ -164,7 +166,10 @@ local function Module__converter(type)
    end
 end
 
-rawset(torch.getmetatable('nn.Module'), 'cudaDouble', Module__converter('torch.CudaDoubleTensor'))
+if not cutorch.minMath then
+  rawset(torch.getmetatable('nn.Module'), 'cudaDouble', Module__converter('torch.CudaDoubleTensor'))
+end
+
 if cutorch.hasHalf then
    rawset(torch.getmetatable('nn.Module'), 'cudaHalf', Module__converter('torch.CudaHalfTensor'))
 end
diff --git a/init.lua b/init.lua
index fd1e319b..6a1ceb6b 100644
--- a/init.lua
+++ b/init.lua
@@ -9,5 +9,4 @@ require('cunn.DataParallelTable')
 
 nn.Module._flattenTensorBuffer['torch.CudaTensor'] = torch.FloatTensor.new
 nn.Module._flattenTensorBuffer['torch.CudaDoubleTensor'] = torch.DoubleTensor.new
--- FIXME: change this to torch.HalfTensor when available
-nn.Module._flattenTensorBuffer['torch.CudaHalfTensor'] = torch.FloatTensor.new
+nn.Module._flattenTensorBuffer['torch.CudaHalfTensor'] = torch.HalfTensor.new
diff --git a/lib/THCUNN/SparseLinear.cu b/lib/THCUNN/SparseLinear.cu
index a7ffa1e2..b1b756ac 100644
--- a/lib/THCUNN/SparseLinear.cu
+++ b/lib/THCUNN/SparseLinear.cu
@@ -82,6 +82,5 @@ void THNN_CudaHalfSparseLinear_updateParameters(
 #endif
 
 #include "generic/SparseLinear.cu"
-#include "THCGenerateFloatType.h"
-#include "generic/SparseLinear.cu"
-#include "THCGenerateDoubleType.h"
+#include "THCGenerateFloatTypes.h"
+
diff --git a/lib/THCUNN/THCGenerateFloatTypes.h b/lib/THCUNN/THCGenerateFloatTypes.h
new file mode 100644
index 00000000..a197eed7
--- /dev/null
+++ b/lib/THCUNN/THCGenerateFloatTypes.h
@@ -0,0 +1,35 @@
+#ifndef THC_GENERIC_FILE
+#error "You must define THC_GENERIC_FILE before including THCGenerateFloatTypes.h"
+#endif
+
+#define THCGenerateFloatTypes
+
+#define THCTypeIdxByte   1
+#define THCTypeIdxChar   2
+#define THCTypeIdxShort  3
+#define THCTypeIdxInt    4
+#define THCTypeIdxLong   5
+#define THCTypeIdxFloat  6
+#define THCTypeIdxDouble 7
+#define THCTypeIdxHalf   8
+#define THCTypeIdx_(T) TH_CONCAT_2(THCTypeIdx,T)
+
+#ifndef THC_MIN_MATH
+#include "THCGenerateHalfType.h"
+#include "THCGenerateDoubleType.h"
+#endif
+
+#include "THCGenerateFloatType.h"
+
+#undef THCTypeIdxByte
+#undef THCTypeIdxChar
+#undef THCTypeIdxShort
+#undef THCTypeIdxInt
+#undef THCTypeIdxLong
+#undef THCTypeIdxFloat
+#undef THCTypeIdxDouble
+#undef THCTypeIdxHalf
+#undef THCTypeIdx_
+
+#undef THCGenerateFloatTypes
+#undef THC_GENERIC_FILE
diff --git a/test.lua b/test.lua
index c3ed9bb2..4a137232 100644
--- a/test.lua
+++ b/test.lua
@@ -11,20 +11,21 @@ local THC = ffi.os == 'Windows' and ffi.load('THC') or ffi.C
 
 --e.g.: th -lcunn -e "nn.testcuda{'Sigmoid_forward'}"
 local typenames = {
-  'torch.CudaTensor',
-  'torch.CudaDoubleTensor',
+  'torch.CudaTensor'
 }
 
 local t2cpu = {
-  ['torch.CudaTensor'] = 'torch.FloatTensor',
-  ['torch.CudaDoubleTensor'] = 'torch.DoubleTensor',
-
+  ['torch.CudaTensor'] = 'torch.FloatTensor'
 }
 
 local function checkHalf()
-   if cutorch.hasHalf then
+  if not cutorch.minMath then
+    if cutorch.hasHalf then
       table.insert(typenames, 'torch.CudaHalfTensor')
-      t2cpu['torch.CudaHalfTensor'] = 'torch.FloatTensor'
-   end
+      t2cpu['torch.CudaHalfTensor'] = 'torch.HalfTensor'
+    end
+    table.insert(typenames, 'torch.CudaDoubleTensor')
+    t2cpu['torch.CudaDoubleTensor'] = 'torch.DoubleTensor'
+  end
 end
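
For context, a minimal sketch of how the new flag is meant to be consumed from Lua. It assumes cutorch and cunn were built together so that compiling with `-DTHC_MIN_MATH` (as the `CMakeLists.txt` hunk above does for Android) surfaces as `cutorch.minMath`; the `cudaDouble` converter and the `CudaDoubleTensor` kernel registration are the ones guarded in the `THCUNN.lua` hunks above.

```lua
-- Illustrative only, not part of the patch: exercising the cutorch.minMath
-- guard. Assumes cutorch.minMath reflects a build with -DTHC_MIN_MATH.
require 'cunn'

local m = nn.Linear(4, 2):cuda()
local input = torch.CudaTensor(4):uniform()

if cutorch.minMath then
   -- Min-math build (e.g. Android): only the float kernels are registered,
   -- so stay on torch.CudaTensor.
   print(m:forward(input))
else
   -- Full build: the cudaDouble converter and the CudaDoubleTensor kernels
   -- registered above are available as well.
   print(m:cudaDouble():forward(input:type('torch.CudaDoubleTensor')))
end
```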