Skip to content

Commit c0f2559

Browse files
authored
Merge pull request #1 from electronicarts/awolfe/Gigi
Gigi source and latest code generation
2 parents f25aad4 + c212de8 commit c0f2559

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+24872
-2221
lines changed

Demo/Demo.vcxproj

+29-10
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
<ConfigurationType>Application</ConfigurationType>
2323
<UseDebugLibraries>true</UseDebugLibraries>
2424
<PlatformToolset>v143</PlatformToolset>
25-
<CharacterSet>Unicode</CharacterSet>
25+
<CharacterSet>MultiByte</CharacterSet>
2626
</PropertyGroup>
2727
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
2828
<ConfigurationType>Application</ConfigurationType>
2929
<UseDebugLibraries>false</UseDebugLibraries>
3030
<PlatformToolset>v143</PlatformToolset>
3131
<WholeProgramOptimization>true</WholeProgramOptimization>
32-
<CharacterSet>Unicode</CharacterSet>
32+
<CharacterSet>MultiByte</CharacterSet>
3333
</PropertyGroup>
3434
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
3535
<ImportGroup Label="ExtensionSettings">
@@ -45,11 +45,11 @@
4545
<PropertyGroup Label="UserMacros" />
4646
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
4747
<LinkIncremental>true</LinkIncremental>
48-
<IncludePath>$(ProjectDir)imgui\;$(IncludePath)</IncludePath>
48+
<IncludePath>$(ProjectDir)imgui\;$(ProjectDir)mnist\;$(ProjectDir)mnist\DX12Utils\tinyexr\deps\miniz\;$(IncludePath)</IncludePath>
4949
</PropertyGroup>
5050
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
5151
<LinkIncremental>false</LinkIncremental>
52-
<IncludePath>$(ProjectDir)imgui\;$(IncludePath)</IncludePath>
52+
<IncludePath>$(ProjectDir)imgui\;$(ProjectDir)mnist\;$(ProjectDir)mnist\DX12Utils\tinyexr\deps\miniz\;$(IncludePath)</IncludePath>
5353
</PropertyGroup>
5454
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
5555
<ClCompile>
@@ -58,6 +58,7 @@
5858
<PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
5959
<ConformanceMode>true</ConformanceMode>
6060
<ObjectFileName>$(IntDir)%(RelativeDir)</ObjectFileName>
61+
<LanguageStandard>stdcpp17</LanguageStandard>
6162
</ClCompile>
6263
<Link>
6364
<SubSystem>Windows</SubSystem>
@@ -75,6 +76,7 @@
7576
<PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
7677
<ConformanceMode>true</ConformanceMode>
7778
<ObjectFileName>$(IntDir)%(RelativeDir)</ObjectFileName>
79+
<LanguageStandard>stdcpp17</LanguageStandard>
7880
</ClCompile>
7981
<Link>
8082
<SubSystem>Windows</SubSystem>
@@ -86,9 +88,6 @@
8688
</Link>
8789
</ItemDefinitionGroup>
8890
<ItemGroup>
89-
<ClCompile Include="mnist\private\dxutils.cpp" />
90-
<ClCompile Include="mnist\private\shadercompiler.cpp" />
91-
<ClCompile Include="mnist\private\technique.cpp" />
9291
<ClCompile Include="imgui\backends\imgui_impl_dx12.cpp" />
9392
<ClCompile Include="imgui\backends\imgui_impl_win32.cpp" />
9493
<ClCompile Include="imgui\imgui.cpp" />
@@ -97,13 +96,33 @@
9796
<ClCompile Include="imgui\imgui_tables.cpp" />
9897
<ClCompile Include="imgui\imgui_widgets.cpp" />
9998
<ClCompile Include="main.cpp" />
99+
<ClCompile Include="mnist\DX12Utils\CompileShaders_dxc.cpp" />
100+
<ClCompile Include="mnist\DX12Utils\CompileShaders_fxc.cpp" />
101+
<ClCompile Include="mnist\DX12Utils\dxutils.cpp" />
102+
<ClCompile Include="mnist\DX12Utils\FileCache.cpp" />
103+
<ClCompile Include="mnist\DX12Utils\TextureCache.cpp" />
104+
<ClCompile Include="mnist\DX12Utils\tinyexr\deps\miniz\miniz.c" />
105+
<ClCompile Include="mnist\private\technique.cpp" />
100106
</ItemGroup>
101107
<ItemGroup>
102-
<ClInclude Include="..\stb\stb_image.h" />
103-
<ClInclude Include="mnist\private\dxutils.h" />
104-
<ClInclude Include="mnist\private\shadercompiler.h" />
108+
<ClInclude Include="mnist\DX12Utils\CompileShaders.h" />
109+
<ClInclude Include="mnist\DX12Utils\DelayedReleaseTracker.h" />
110+
<ClInclude Include="mnist\DX12Utils\dxutils.h" />
111+
<ClInclude Include="mnist\DX12Utils\FileCache.h" />
112+
<ClInclude Include="mnist\DX12Utils\HeapAllocationTracker.h" />
113+
<ClInclude Include="mnist\DX12Utils\logfn.h" />
114+
<ClInclude Include="mnist\DX12Utils\ParseCSV.h" />
115+
<ClInclude Include="mnist\DX12Utils\ReadbackHelper.h" />
116+
<ClInclude Include="mnist\DX12Utils\SRGB.h" />
117+
<ClInclude Include="mnist\DX12Utils\stb\stb_image.h" />
118+
<ClInclude Include="mnist\DX12Utils\stb\stb_image_write.h" />
119+
<ClInclude Include="mnist\DX12Utils\TextureCache.h" />
120+
<ClInclude Include="mnist\DX12Utils\tinyexr\deps\miniz\miniz.h" />
121+
<ClInclude Include="mnist\DX12Utils\tinyexr\tinyexr.h" />
105122
<ClInclude Include="mnist\private\technique.h" />
123+
<ClInclude Include="mnist\public\all.h" />
106124
<ClInclude Include="mnist\public\imgui.h" />
125+
<ClInclude Include="mnist\public\pythoninterface.h" />
107126
<ClInclude Include="mnist\public\technique.h" />
108127
</ItemGroup>
109128
<ItemGroup>

Demo/Demo.vcxproj.filters

+81-18
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
33
<ItemGroup>
44
<ClCompile Include="main.cpp" />
5-
<ClCompile Include="mnist\private\dxutils.cpp">
6-
<Filter>mnist\private</Filter>
7-
</ClCompile>
85
<ClCompile Include="imgui\imgui.cpp">
96
<Filter>imgui</Filter>
107
</ClCompile>
@@ -26,31 +23,85 @@
2623
<ClCompile Include="imgui\imgui_widgets.cpp">
2724
<Filter>imgui</Filter>
2825
</ClCompile>
29-
<ClCompile Include="mnist\private\shadercompiler.cpp">
30-
<Filter>mnist\private</Filter>
31-
</ClCompile>
3226
<ClCompile Include="mnist\private\technique.cpp">
3327
<Filter>mnist\private</Filter>
3428
</ClCompile>
29+
<ClCompile Include="mnist\DX12Utils\CompileShaders_dxc.cpp">
30+
<Filter>mnist\DX12Utils</Filter>
31+
</ClCompile>
32+
<ClCompile Include="mnist\DX12Utils\CompileShaders_fxc.cpp">
33+
<Filter>mnist\DX12Utils</Filter>
34+
</ClCompile>
35+
<ClCompile Include="mnist\DX12Utils\dxutils.cpp">
36+
<Filter>mnist\DX12Utils</Filter>
37+
</ClCompile>
38+
<ClCompile Include="mnist\DX12Utils\FileCache.cpp">
39+
<Filter>mnist\DX12Utils</Filter>
40+
</ClCompile>
41+
<ClCompile Include="mnist\DX12Utils\TextureCache.cpp">
42+
<Filter>mnist\DX12Utils</Filter>
43+
</ClCompile>
44+
<ClCompile Include="mnist\DX12Utils\tinyexr\deps\miniz\miniz.c">
45+
<Filter>mnist\DX12Utils\tinyexr\deps\miniz</Filter>
46+
</ClCompile>
3547
</ItemGroup>
3648
<ItemGroup>
37-
<ClInclude Include="mnist\private\dxutils.h">
38-
<Filter>mnist\private</Filter>
49+
<ClInclude Include="mnist\public\all.h">
50+
<Filter>mnist\public</Filter>
3951
</ClInclude>
40-
<ClInclude Include="mnist\private\shadercompiler.h">
41-
<Filter>mnist\private</Filter>
52+
<ClInclude Include="mnist\public\imgui.h">
53+
<Filter>mnist\public</Filter>
4254
</ClInclude>
43-
<ClInclude Include="mnist\private\technique.h">
44-
<Filter>mnist\private</Filter>
55+
<ClInclude Include="mnist\public\pythoninterface.h">
56+
<Filter>mnist\public</Filter>
4557
</ClInclude>
4658
<ClInclude Include="mnist\public\technique.h">
4759
<Filter>mnist\public</Filter>
4860
</ClInclude>
49-
<ClInclude Include="mnist\public\imgui.h">
50-
<Filter>mnist\public</Filter>
61+
<ClInclude Include="mnist\private\technique.h">
62+
<Filter>mnist\private</Filter>
63+
</ClInclude>
64+
<ClInclude Include="mnist\DX12Utils\CompileShaders.h">
65+
<Filter>mnist\DX12Utils</Filter>
66+
</ClInclude>
67+
<ClInclude Include="mnist\DX12Utils\DelayedReleaseTracker.h">
68+
<Filter>mnist\DX12Utils</Filter>
69+
</ClInclude>
70+
<ClInclude Include="mnist\DX12Utils\dxutils.h">
71+
<Filter>mnist\DX12Utils</Filter>
72+
</ClInclude>
73+
<ClInclude Include="mnist\DX12Utils\FileCache.h">
74+
<Filter>mnist\DX12Utils</Filter>
75+
</ClInclude>
76+
<ClInclude Include="mnist\DX12Utils\HeapAllocationTracker.h">
77+
<Filter>mnist\DX12Utils</Filter>
78+
</ClInclude>
79+
<ClInclude Include="mnist\DX12Utils\logfn.h">
80+
<Filter>mnist\DX12Utils</Filter>
81+
</ClInclude>
82+
<ClInclude Include="mnist\DX12Utils\ParseCSV.h">
83+
<Filter>mnist\DX12Utils</Filter>
84+
</ClInclude>
85+
<ClInclude Include="mnist\DX12Utils\ReadbackHelper.h">
86+
<Filter>mnist\DX12Utils</Filter>
5187
</ClInclude>
52-
<ClInclude Include="..\stb\stb_image.h">
53-
<Filter>stb</Filter>
88+
<ClInclude Include="mnist\DX12Utils\SRGB.h">
89+
<Filter>mnist\DX12Utils</Filter>
90+
</ClInclude>
91+
<ClInclude Include="mnist\DX12Utils\TextureCache.h">
92+
<Filter>mnist\DX12Utils</Filter>
93+
</ClInclude>
94+
<ClInclude Include="mnist\DX12Utils\stb\stb_image.h">
95+
<Filter>mnist\DX12Utils\stb</Filter>
96+
</ClInclude>
97+
<ClInclude Include="mnist\DX12Utils\stb\stb_image_write.h">
98+
<Filter>mnist\DX12Utils\stb</Filter>
99+
</ClInclude>
100+
<ClInclude Include="mnist\DX12Utils\tinyexr\tinyexr.h">
101+
<Filter>mnist\DX12Utils\tinyexr</Filter>
102+
</ClInclude>
103+
<ClInclude Include="mnist\DX12Utils\tinyexr\deps\miniz\miniz.h">
104+
<Filter>mnist\DX12Utils\tinyexr\deps\miniz</Filter>
54105
</ClInclude>
55106
</ItemGroup>
56107
<ItemGroup>
@@ -69,8 +120,20 @@
69120
<Filter Include="mnist\shaders">
70121
<UniqueIdentifier>{5b06f838-eb45-4e0b-a445-ea8638ccc919}</UniqueIdentifier>
71122
</Filter>
72-
<Filter Include="stb">
73-
<UniqueIdentifier>{3ee28b1b-40ee-4855-acf6-5fbec8bbdfa7}</UniqueIdentifier>
123+
<Filter Include="mnist\DX12Utils">
124+
<UniqueIdentifier>{ff0e1c52-a8a7-47b5-9e9c-a7c14673cc8c}</UniqueIdentifier>
125+
</Filter>
126+
<Filter Include="mnist\DX12Utils\stb">
127+
<UniqueIdentifier>{c49b5917-47d2-442d-86e3-7429ebfd0ac4}</UniqueIdentifier>
128+
</Filter>
129+
<Filter Include="mnist\DX12Utils\tinyexr">
130+
<UniqueIdentifier>{48ef629a-d601-4321-bb81-5486635b3e77}</UniqueIdentifier>
131+
</Filter>
132+
<Filter Include="mnist\DX12Utils\tinyexr\deps">
133+
<UniqueIdentifier>{066e23ac-5839-4b47-9a06-ee13858ba555}</UniqueIdentifier>
134+
</Filter>
135+
<Filter Include="mnist\DX12Utils\tinyexr\deps\miniz">
136+
<UniqueIdentifier>{92005e30-4bbe-4bea-96fb-0fd1dc33a900}</UniqueIdentifier>
74137
</Filter>
75138
</ItemGroup>
76139
<ItemGroup>

Demo/dxcompiler.dll

14 MB
Binary file not shown.

Demo/dxil.dll

1.44 MB
Binary file not shown.

Demo/main.cpp

+7-53
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,14 @@
1818
#include "mnist/public/technique.h"
1919
#include "mnist/public/imgui.h"
2020

21-
#define STB_IMAGE_IMPLEMENTATION
22-
#include "../stb/stb_image.h"
21+
#include "mnist/DX12Utils/stb/stb_image.h"
2322

2423
// Note: this being true can cause crashes in nsight (nsight says so on startup)
2524
#define BREAK_ON_DX12_ERROR() _DEBUG
2625

2726
static unsigned int c_width = 1280;
2827
static unsigned int c_height = 1000;
29-
static const wchar_t* c_windowTitle = L"MNIST Neural Network Demo";
28+
static const char* c_windowTitle = "MNIST Neural Network Demo";
3029
static const bool g_useWarpDevice = false;
3130
static const UINT FrameCount = 2;
3231
static const bool c_enableGPUBasedValidation = false;
@@ -235,7 +234,7 @@ struct DX12Data
235234
if (!m_mnist->m_input.buffer_NN_Weights)
236235
{
237236
std::vector<char> weights = LoadBinaryFileIntoMemory("mnist/assets/Backprop_Weights.bin");
238-
m_mnist->m_input.buffer_NN_Weights = m_mnist->CreateManagedBuffer(m_device, (unsigned int)weights.size(), m_mnist->m_input.c_buffer_NN_Weights_flags, D3D12_RESOURCE_STATE_COMMON, D3D12_HEAP_TYPE_DEFAULT, m_commandList, weights.data(), L"MNIST NNWeights");
237+
m_mnist->m_input.buffer_NN_Weights = m_mnist->CreateManagedBuffer(m_device, m_commandList, m_mnist->m_input.c_buffer_NN_Weights_flags, weights.data(), (unsigned int)weights.size(), L"MNIST NNWeights", D3D12_RESOURCE_STATE_COMMON);
239238
m_mnist->m_input.buffer_NN_Weights_format = DXGI_FORMAT_R32_FLOAT;
240239
m_mnist->m_input.buffer_NN_Weights_stride = 0;
241240
m_mnist->m_input.buffer_NN_Weights_count = (unsigned int)(weights.size() / sizeof(float));
@@ -246,7 +245,7 @@ struct DX12Data
246245
if (!m_mnist->m_input.texture_Imported_Image)
247246
{
248247
static const unsigned int c_size[2] = { 28, 28 };
249-
m_mnist->m_input.texture_Imported_Image = m_mnist->CreateManagedTexture2D(m_device, c_size, DXGI_FORMAT_R8_UNORM, m_mnist->m_input.texture_Imported_Image_flags, D3D12_RESOURCE_STATE_COMMON, m_commandList, nullptr, 0, L"MNIST Imported Image");
248+
m_mnist->m_input.texture_Imported_Image = m_mnist->CreateManagedTexture(m_device, m_commandList, m_mnist->m_input.texture_Imported_Image_flags, DXGI_FORMAT_R8_UNORM, c_size, 1, DX12Utils::ResourceType::Texture2D, nullptr, L"MNIST Imported Image", D3D12_RESOURCE_STATE_COMMON);
250249
m_mnist->m_input.texture_Imported_Image_size[0] = 28;
251250
m_mnist->m_input.texture_Imported_Image_size[1] = 28;
252251
m_mnist->m_input.texture_Imported_Image_size[2] = 1;
@@ -635,58 +634,14 @@ struct DX12Data
635634
}
636635
}
637636

638-
template <typename T>
639-
static std::string GetAssetPath();
640-
641-
template <>
642-
static std::string GetAssetPath<mnist::LoadTextureData>()
643-
{
644-
return "mnist/assets/";
645-
}
646-
647-
template <typename T>
648-
static bool GigiLoadTexture(T& data)
649-
{
650-
std::string fullFileName = GetAssetPath<T>() + data.fileName;
651-
652-
std::string extension;
653-
size_t extensionStart = fullFileName.find_last_of(".");
654-
if (extensionStart != std::string::npos)
655-
extension = fullFileName.substr(extensionStart);
656-
if (extension == ".hdr")
657-
{
658-
int c;
659-
float* pixels = stbi_loadf(fullFileName.c_str(), &data.width, &data.height, &c, data.numChannels);
660-
if (!pixels)
661-
return false;
662-
663-
data.pixelsF32.resize(data.width * data.height * data.numChannels);
664-
memcpy(data.pixelsF32.data(), pixels, data.pixelsF32.size() * sizeof(float));
665-
stbi_image_free(pixels);
666-
}
667-
else
668-
{
669-
int c;
670-
unsigned char* pixels = stbi_load(fullFileName.c_str(), &data.width, &data.height, &c, data.numChannels);
671-
if (!pixels)
672-
return false;
673-
674-
data.pixelsU8.resize(data.width * data.height * data.numChannels);
675-
memcpy(data.pixelsU8.data(), pixels, data.pixelsU8.size());
676-
stbi_image_free(pixels);
677-
}
678-
679-
return true;
680-
}
681-
682-
static void GigiLogFn(int level, const char* msg, ...)
637+
static void GigiLogFn(LogLevel level, const char* msg, ...)
683638
{
684639
static std::vector<char> buffer(40960);
685640
va_list args;
686641
va_start(args, msg);
687642
vsprintf_s(buffer.data(), buffer.size(), msg, args);
688643
va_end(args);
689-
if (level >= 2)
644+
if (level == LogLevel::Error)
690645
Assert(false, "Gigi: %s", buffer.data());
691646
}
692647

@@ -711,7 +666,6 @@ struct DX12Data
711666

712667
// Set the logging function, perf marker functions, and shader locations, and create the technique contexts
713668
mnist::Context::LogFn = &GigiLogFn;
714-
mnist::Context::LoadTextureFn = &GigiLoadTexture<mnist::LoadTextureData>;
715669
mnist::Context::s_techniqueLocation = L"mnist/";
716670
m_mnist = mnist::CreateContext(m_device);
717671
Assert(m_mnist != nullptr, "Could not create mnist context");
@@ -834,7 +788,7 @@ void InitializeGraphics()
834788
windowClass.lpfnWndProc = WindowProc;
835789
windowClass.hInstance = s_hInstance;
836790
windowClass.hCursor = LoadCursor(NULL, IDC_ARROW);
837-
windowClass.lpszClassName = L"MNISTNN";
791+
windowClass.lpszClassName = "MNISTNN";
838792
RegisterClassEx(&windowClass);
839793

840794
RECT windowRect = { 0, 0, (LONG)c_width, (LONG)c_height };

0 commit comments

Comments
 (0)