Skip to content

Commit 153a4d0

Browse files
committed
Guard against null for objects passed by value
Fixes #1228 Signed-off-by: Dimitar Dobrev <[email protected]>
1 parent 3cd7fde commit 153a4d0

File tree

5 files changed

+62
-36
lines changed

5 files changed

+62
-36
lines changed

Diff for: src/Generator/Generators/CLI/CLIMarshal.cs

+7
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,13 @@ private void MarshalRefClass(Class @class)
679679

680680
if (Context.Parameter.Type.IsReference())
681681
VarPrefix.Write("&");
682+
else
683+
{
684+
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, nullptr))");
685+
Context.Before.WriteLineIndent(
686+
$@"throw gcnew ::System::ArgumentNullException(""{
687+
Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");");
688+
}
682689
}
683690

684691
if (method != null

Diff for: src/Generator/Generators/CSharp/CSharpMarshal.cs

+47-32
Original file line numberDiff line numberDiff line change
@@ -701,46 +701,61 @@ private void MarshalRefClass(Class @class)
701701
@interface.IsInterface)
702702
paramInstance = $"{param}.__PointerTo{@interface.OriginalClass.Name}";
703703
else
704-
paramInstance = $@"{param}.{Helpers.InstanceIdentifier}";
705-
if (type.IsAddress())
704+
paramInstance = $"{param}.{Helpers.InstanceIdentifier}";
705+
706+
if (!type.IsAddress())
707+
{
708+
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))");
709+
Context.Before.WriteLineIndent(
710+
$@"throw new global::System.ArgumentNullException(""{
711+
Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");");
712+
var realClass = @class.OriginalClass ?? @class;
713+
var qualifiedIdentifier = typePrinter.PrintNative(realClass);
714+
Context.Return.Write($"*({qualifiedIdentifier}*) {paramInstance}");
715+
return;
716+
}
717+
718+
Class decl;
719+
if (type.TryGetClass(out decl) && decl.IsValueType)
720+
{
721+
Context.Return.Write(paramInstance);
722+
return;
723+
}
724+
725+
if (type.IsPointer())
706726
{
707-
Class decl;
708-
if (type.TryGetClass(out decl) && decl.IsValueType)
727+
if (Context.Parameter.IsIndirect)
728+
{
729+
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))");
730+
Context.Before.WriteLineIndent(
731+
$@"throw new global::System.ArgumentNullException(""{
732+
Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");");
709733
Context.Return.Write(paramInstance);
734+
}
710735
else
711736
{
712-
if (type.IsPointer())
713-
{
714-
Context.Return.Write("{0}{1}",
715-
method != null && method.OperatorKind == CXXOperatorKind.EqualEqual
716-
? string.Empty
717-
: $"ReferenceEquals({param}, null) ? global::System.IntPtr.Zero : ",
718-
paramInstance);
719-
}
720-
else
721-
{
722-
if (method == null ||
723-
// redundant for comparison operators, they are handled in a special way
724-
(method.OperatorKind != CXXOperatorKind.EqualEqual &&
725-
method.OperatorKind != CXXOperatorKind.ExclaimEqual))
726-
{
727-
Context.Before.WriteLine("if (ReferenceEquals({0}, null))", param);
728-
Context.Before.WriteLineIndent(
729-
"throw new global::System.ArgumentNullException(\"{0}\", " +
730-
"\"Cannot be null because it is a C++ reference (&).\");",
731-
param);
732-
}
733-
Context.Return.Write(paramInstance);
734-
}
737+
Context.Return.Write("{0}{1}",
738+
method != null && method.OperatorKind == CXXOperatorKind.EqualEqual
739+
? string.Empty
740+
: $"ReferenceEquals({param}, null) ? global::System.IntPtr.Zero : ",
741+
paramInstance);
735742
}
736743
return;
737744
}
738745

739-
var realClass = @class.OriginalClass ?? @class;
740-
var qualifiedIdentifier = typePrinter.PrintNative(realClass);
741-
Context.Return.Write(
742-
"ReferenceEquals({0}, null) ? new {1}() : *({1}*) {2}",
743-
param, qualifiedIdentifier, paramInstance);
746+
if (method == null ||
747+
// redundant for comparison operators, they are handled in a special way
748+
(method.OperatorKind != CXXOperatorKind.EqualEqual &&
749+
method.OperatorKind != CXXOperatorKind.ExclaimEqual))
750+
{
751+
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))");
752+
Context.Before.WriteLineIndent(
753+
$@"throw new global::System.ArgumentNullException(""{
754+
Context.Parameter.Name}"", ""Cannot be null because it is a C++ reference (&)."");",
755+
param);
756+
}
757+
758+
Context.Return.Write(paramInstance);
744759
}
745760

746761
private void MarshalValueClass()

Diff for: src/Generator/Passes/CheckAbiParameters.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ public override bool VisitFunctionDecl(Function function)
8181
});
8282
}
8383

84-
foreach (var param in from p in function.Parameters
85-
where p.IsIndirect && !p.Type.Desugar().IsAddress()
86-
select p)
84+
foreach (var param in function.Parameters.Where(p => p.IsIndirect))
8785
{
8886
param.QualifiedType = new QualifiedType(new PointerType(param.QualifiedType));
8987
}

Diff for: tests/CSharp/CSharp.Tests.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ public void TestDefaultArguments()
248248
methodsWithDefaultValues.DefaultEmptyEnum();
249249
methodsWithDefaultValues.DefaultRefTypeBeforeOthers();
250250
methodsWithDefaultValues.DefaultRefTypeAfterOthers();
251-
methodsWithDefaultValues.DefaultRefTypeBeforeAndAfterOthers(0, null);
251+
methodsWithDefaultValues.DefaultRefTypeBeforeAndAfterOthers();
252252
methodsWithDefaultValues.DefaultIntAssignedAnEnum();
253253
methodsWithDefaultValues.defaultRefAssignedValue();
254254
methodsWithDefaultValues.DefaultRefAssignedValue();

Diff for: tests/Common/Common.Tests.cs

+6
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,12 @@ public void TestPassingNullToRef()
772772
}
773773
}
774774

775+
[Test]
776+
public void TestPassingNullToValue()
777+
{
778+
Assert.Catch<ArgumentNullException>(() => new Bar((Foo) null));
779+
}
780+
775781
[Test]
776782
public void TestNonTrivialDtorInvocation()
777783
{

0 commit comments

Comments
 (0)