18
18
#include " mnist/public/technique.h"
19
19
#include " mnist/public/imgui.h"
20
20
21
- #define STB_IMAGE_IMPLEMENTATION
22
- #include " ../stb/stb_image.h"
21
+ #include " mnist/DX12Utils/stb/stb_image.h"
23
22
24
23
// Note: this being true can cause crashes in nsight (nsight says so on startup)
25
24
#define BREAK_ON_DX12_ERROR () _DEBUG
26
25
27
26
static unsigned int c_width = 1280 ;
28
27
static unsigned int c_height = 1000 ;
29
- static const wchar_t * c_windowTitle = L " MNIST Neural Network Demo" ;
28
+ static const char * c_windowTitle = " MNIST Neural Network Demo" ;
30
29
static const bool g_useWarpDevice = false ;
31
30
static const UINT FrameCount = 2 ;
32
31
static const bool c_enableGPUBasedValidation = false ;
@@ -235,7 +234,7 @@ struct DX12Data
235
234
if (!m_mnist->m_input .buffer_NN_Weights )
236
235
{
237
236
std::vector<char > weights = LoadBinaryFileIntoMemory (" mnist/assets/Backprop_Weights.bin" );
238
- m_mnist->m_input .buffer_NN_Weights = m_mnist->CreateManagedBuffer (m_device, ( unsigned int )weights. size () , m_mnist->m_input .c_buffer_NN_Weights_flags , D3D12_RESOURCE_STATE_COMMON, D3D12_HEAP_TYPE_DEFAULT, m_commandList, weights.data (), L" MNIST NNWeights" );
237
+ m_mnist->m_input .buffer_NN_Weights = m_mnist->CreateManagedBuffer (m_device, m_commandList , m_mnist->m_input .c_buffer_NN_Weights_flags , weights. data (), ( unsigned int ) weights.size (), L" MNIST NNWeights" , D3D12_RESOURCE_STATE_COMMON );
239
238
m_mnist->m_input .buffer_NN_Weights_format = DXGI_FORMAT_R32_FLOAT;
240
239
m_mnist->m_input .buffer_NN_Weights_stride = 0 ;
241
240
m_mnist->m_input .buffer_NN_Weights_count = (unsigned int )(weights.size () / sizeof (float ));
@@ -246,7 +245,7 @@ struct DX12Data
246
245
if (!m_mnist->m_input .texture_Imported_Image )
247
246
{
248
247
static const unsigned int c_size[2 ] = { 28 , 28 };
249
- m_mnist->m_input .texture_Imported_Image = m_mnist->CreateManagedTexture2D (m_device, c_size, DXGI_FORMAT_R8_UNORM, m_mnist->m_input .texture_Imported_Image_flags , D3D12_RESOURCE_STATE_COMMON, m_commandList, nullptr , 0 , L" MNIST Imported Image" );
248
+ m_mnist->m_input .texture_Imported_Image = m_mnist->CreateManagedTexture (m_device, m_commandList, m_mnist->m_input .texture_Imported_Image_flags , DXGI_FORMAT_R8_UNORM, c_size, 1 , DX12Utils::ResourceType::Texture2D, nullptr , L" MNIST Imported Image" , D3D12_RESOURCE_STATE_COMMON );
250
249
m_mnist->m_input .texture_Imported_Image_size [0 ] = 28 ;
251
250
m_mnist->m_input .texture_Imported_Image_size [1 ] = 28 ;
252
251
m_mnist->m_input .texture_Imported_Image_size [2 ] = 1 ;
@@ -635,58 +634,14 @@ struct DX12Data
635
634
}
636
635
}
637
636
638
- template <typename T>
639
- static std::string GetAssetPath ();
640
-
641
- template <>
642
- static std::string GetAssetPath<mnist::LoadTextureData>()
643
- {
644
- return " mnist/assets/" ;
645
- }
646
-
647
- template <typename T>
648
- static bool GigiLoadTexture (T& data)
649
- {
650
- std::string fullFileName = GetAssetPath<T>() + data.fileName ;
651
-
652
- std::string extension;
653
- size_t extensionStart = fullFileName.find_last_of (" ." );
654
- if (extensionStart != std::string::npos)
655
- extension = fullFileName.substr (extensionStart);
656
- if (extension == " .hdr" )
657
- {
658
- int c;
659
- float * pixels = stbi_loadf (fullFileName.c_str (), &data.width , &data.height , &c, data.numChannels );
660
- if (!pixels)
661
- return false ;
662
-
663
- data.pixelsF32 .resize (data.width * data.height * data.numChannels );
664
- memcpy (data.pixelsF32 .data (), pixels, data.pixelsF32 .size () * sizeof (float ));
665
- stbi_image_free (pixels);
666
- }
667
- else
668
- {
669
- int c;
670
- unsigned char * pixels = stbi_load (fullFileName.c_str (), &data.width , &data.height , &c, data.numChannels );
671
- if (!pixels)
672
- return false ;
673
-
674
- data.pixelsU8 .resize (data.width * data.height * data.numChannels );
675
- memcpy (data.pixelsU8 .data (), pixels, data.pixelsU8 .size ());
676
- stbi_image_free (pixels);
677
- }
678
-
679
- return true ;
680
- }
681
-
682
- static void GigiLogFn (int level, const char * msg, ...)
637
+ static void GigiLogFn (LogLevel level, const char * msg, ...)
683
638
{
684
639
static std::vector<char > buffer (40960 );
685
640
va_list args;
686
641
va_start (args, msg);
687
642
vsprintf_s (buffer.data (), buffer.size (), msg, args);
688
643
va_end (args);
689
- if (level >= 2 )
644
+ if (level == LogLevel::Error )
690
645
Assert (false , " Gigi: %s" , buffer.data ());
691
646
}
692
647
@@ -711,7 +666,6 @@ struct DX12Data
711
666
712
667
// Set the logging function, perf marker functions, and shader locations, and create the technique contexts
713
668
mnist::Context::LogFn = &GigiLogFn;
714
- mnist::Context::LoadTextureFn = &GigiLoadTexture<mnist::LoadTextureData>;
715
669
mnist::Context::s_techniqueLocation = L" mnist/" ;
716
670
m_mnist = mnist::CreateContext (m_device);
717
671
Assert (m_mnist != nullptr , " Could not create mnist context" );
@@ -834,7 +788,7 @@ void InitializeGraphics()
834
788
windowClass.lpfnWndProc = WindowProc;
835
789
windowClass.hInstance = s_hInstance;
836
790
windowClass.hCursor = LoadCursor (NULL , IDC_ARROW);
837
- windowClass.lpszClassName = L " MNISTNN" ;
791
+ windowClass.lpszClassName = " MNISTNN" ;
838
792
RegisterClassEx (&windowClass);
839
793
840
794
RECT windowRect = { 0 , 0 , (LONG)c_width, (LONG)c_height };
0 commit comments