Skip to content

Commit 9856487

Browse files
Merge pull request #1393 from mvphelps/ToTensorMemoryLeakFix
Fix memory leaks in .ToTensor, resolves #1392
2 parents d545f3d + ce18679 commit 9856487

9 files changed

+85
-91
lines changed

RELEASENOTES.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ __Bug Fixes__:
1212
#1383 `torch.linalg.vector_norm`: Make `ord`-argument optional, as specified in docs<br/>
1313
#1385 PackedSequence now participates in the DisposeScope system at the same level as Tensor objects.<br/>
1414
#1387 Attaching tensor to a DisposeScope no longer makes Statistics.DetachedFromScopeCount go negative.<br/>
15-
#1390 DisposeScopeManager.Statistics now includes DisposedOutsideScopeCount and AttachedToScopeCount. ThreadTotalLiveCount is now exact instead of approximate. ToString gives a useful debug string, and documentation is added for how to troubleshoot memory leaks. Also DisposeScopeManager.Statistics.TensorStatistics and DisposeScopeManager.Statistics.PackedSequenceStatistics provide separate metrics for these objects.
15+
#1390 DisposeScopeManager.Statistics now includes DisposedOutsideScopeCount and AttachedToScopeCount. ThreadTotalLiveCount is now exact instead of approximate. ToString gives a useful debug string, and documentation is added for how to troubleshoot memory leaks. Also DisposeScopeManager.Statistics.TensorStatistics and DisposeScopeManager.Statistics.PackedSequenceStatistics provide separate metrics for these objects.<br/>
16+
#1392 ToTensor() extension method memory leaks fixed.<br/>
1617

1718
# NuGet Version 0.103.0
1819

src/TorchSharp/Tensor/Factories/tensor_Complex.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ public static Tensor tensor(System.Numerics.Complex scalar, ScalarType? dtype =
1818
device = InitializeDevice(device);
1919
var handle = THSTensor_newComplexFloat64Scalar(scalar.Real, scalar.Imaginary, (int)device.type, device.index, requires_grad);
2020
if (handle == IntPtr.Zero) { CheckForErrors(); }
21-
var tensor = new Tensor(handle);
22-
tensor = dtype.HasValue ? tensor.to(dtype.Value, device) : tensor.to(device);
23-
return tensor;
21+
return InstantiateTensorWithLeakSafeTypeChange(handle, dtype);
2422
}
2523

2624
/// <summary>

src/TorchSharp/Tensor/Factories/tensor_double.cs

+6-22
Original file line numberDiff line numberDiff line change
@@ -18,43 +18,27 @@ public static Tensor tensor(double scalar, ScalarType? dtype = null, Device? dev
1818
device = InitializeDevice(device);
1919
var handle = THSTensor_newFloat64Scalar(scalar, (int)device.type, device.index, requires_grad);
2020
if (handle == IntPtr.Zero) { CheckForErrors(); }
21-
var tensor = new Tensor(handle);
22-
tensor = dtype.HasValue ? tensor.to(dtype.Value, device) : tensor.to(device);
23-
return tensor;
21+
return InstantiateTensorWithLeakSafeTypeChange(handle, dtype);
2422
}
2523

2624
/// <summary>
27-
/// Create a scalar tensor from a single value
25+
/// Create a scalar complex number tensor from a tuple of (real, imaginary)
2826
/// </summary>
2927
public static Tensor tensor((double Real, double Imaginary) scalar, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
3028
{
31-
device = InitializeDevice(device);
32-
var handle = THSTensor_newComplexFloat64Scalar(scalar.Real, scalar.Imaginary, (int)device.type, device.index, requires_grad);
33-
if (handle == IntPtr.Zero) { CheckForErrors(); }
34-
var tensor = new Tensor(handle);
35-
if (device is { }) {
36-
tensor = dtype.HasValue ? tensor.to(dtype.Value, device) : tensor.to(device);
37-
} else if (dtype.HasValue) {
38-
tensor = tensor.to_type(dtype.Value);
39-
}
40-
return tensor;
29+
return tensor(scalar.Real, scalar.Imaginary, dtype, device, requires_grad);
4130
}
4231

4332
/// <summary>
44-
/// Create a scalar tensor from a single value
33+
/// Create a scalar complex number tensor from independent real and imaginary components
4534
/// </summary>
4635
public static Tensor tensor(double real, double imaginary, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
4736
{
37+
4838
device = InitializeDevice(device);
4939
var handle = THSTensor_newComplexFloat64Scalar(real, imaginary, (int)device.type, device.index, requires_grad);
5040
if (handle == IntPtr.Zero) { CheckForErrors(); }
51-
var tensor = new Tensor(handle);
52-
if (device is { }) {
53-
tensor = dtype.HasValue ? tensor.to(dtype.Value, device) : tensor.to(device);
54-
} else if (dtype.HasValue) {
55-
tensor = tensor.to_type(dtype.Value);
56-
}
57-
return tensor;
41+
return InstantiateTensorWithLeakSafeTypeChange(handle, dtype);
5842
}
5943

6044
/// <summary>

src/TorchSharp/Tensor/Factories/tensor_float.cs

+4-19
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,22 @@ public static Tensor tensor(float scalar, Device? device = null, bool requires_g
2222
}
2323

2424
/// <summary>
25-
/// Create a scalar tensor from a single value
25+
/// Create a scalar complex number tensor from independent real and imaginary components
2626
/// </summary>
2727
public static Tensor tensor(float real, float imaginary, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
2828
{
2929
device = InitializeDevice(device);
3030
var handle = THSTensor_newComplexFloat32Scalar(real, imaginary, (int)device.type, device.index, requires_grad);
3131
if (handle == IntPtr.Zero) { CheckForErrors(); }
32-
var tensor = new Tensor(handle);
33-
if (device is { }) {
34-
tensor = dtype.HasValue ? tensor.to(dtype.Value, device) : tensor.to(device);
35-
} else if (dtype.HasValue) {
36-
tensor = tensor.to_type(dtype.Value);
37-
}
38-
return tensor;
32+
return InstantiateTensorWithLeakSafeTypeChange(handle, dtype);
3933
}
4034

4135
/// <summary>
42-
/// Create a scalar tensor from a single value
36+
/// Create a scalar complex number tensor from a tuple of (real, imaginary)
4337
/// </summary>
4438
public static Tensor tensor((float Real, float Imaginary) scalar, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
4539
{
46-
device = InitializeDevice(device);
47-
var handle = THSTensor_newComplexFloat32Scalar(scalar.Real, scalar.Imaginary, (int)device.type, device.index, requires_grad);
48-
if (handle == IntPtr.Zero) { CheckForErrors(); }
49-
var tensor = new Tensor(handle);
50-
if (device is { }) {
51-
tensor = dtype.HasValue ? tensor.to(dtype.Value, device) : tensor.to(device);
52-
} else if (dtype.HasValue) {
53-
tensor = tensor.to_type(dtype.Value);
54-
}
55-
return tensor;
40+
return tensor(scalar.Real, scalar.Imaginary, dtype: dtype, device: device);
5641
}
5742

5843
/// <summary>

src/TorchSharp/Tensor/Factories/tensor_long.cs

+1-7
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,7 @@ public static Tensor tensor(long scalar, ScalarType? dtype = null, Device? devic
1818
device = InitializeDevice(device);
1919
var handle = THSTensor_newInt64Scalar(scalar, (int)device.type, device.index, requires_grad);
2020
if (handle == IntPtr.Zero) { CheckForErrors(); }
21-
var tensor = new Tensor(handle);
22-
if (device is { }) {
23-
tensor = dtype.HasValue ? tensor.to(dtype.Value, device) : tensor.to(device);
24-
} else if (dtype.HasValue) {
25-
tensor = tensor.to_type(dtype.Value);
26-
}
27-
return tensor;
21+
return InstantiateTensorWithLeakSafeTypeChange(handle, dtype);
2822
}
2923

3024
/// <summary>

src/TorchSharp/Tensor/Tensor.cs

+10
Original file line numberDiff line numberDiff line change
@@ -7402,5 +7402,15 @@ public static Tensor WrappedTensorDisposeScope(Func<Tensor> expr)
74027402
var result = expr();
74037403
return result.MoveToOuterDisposeScope();
74047404
}
7405+
internal static Tensor InstantiateTensorWithLeakSafeTypeChange(IntPtr handle, ScalarType? dtype)
7406+
{
7407+
var tensor = new Tensor(handle);
7408+
if (dtype.HasValue && tensor.dtype != dtype.Value) {
7409+
var typed = tensor.to_type(dtype.Value);
7410+
tensor.Dispose();
7411+
return typed;
7412+
}
7413+
return tensor;
7414+
}
74057415
}
74067416
}

test/TorchSharpTest/TestDisposeScopesStatisticsTensor.cs

-21
Original file line numberDiff line numberDiff line change
@@ -73,26 +73,5 @@ public void DisposingScopeAfterDetachingDoesNothing()
7373
AssertTensorCounts(0, 0, 1, 0, 0, 1, 1);
7474
AssertTotalsCounts(0, 0, 1, 0, 0, 1, 1);
7575
}
76-
77-
[Fact]
78-
public void ToTensorCreatesOrphanedTensorButDisposeScopeCleansItUp()
79-
{
80-
//Defect: This needs fixing but is unrelated to the commit that discovered
81-
//it - adding better lifetime statistics. ToTensor() leaks 1 tensor
82-
//every time it is called.
83-
var scope = torch.NewDisposeScope();
84-
var stats = DisposeScopeManager.Statistics;
85-
stats.Reset();
86-
var a1 = 1.ToTensor();
87-
Assert.Equal(2, stats.CreatedInScopeCount);
88-
//Should be 1, or can remain 0 if CreatedInScope becomes 1.
89-
Assert.Equal(0, stats.DisposedInScopeCount);
90-
a1.Dispose();
91-
Assert.Equal(1, stats.DisposedInScopeCount);
92-
93-
//Should not need this if no orphan.
94-
scope.Dispose();
95-
Assert.Equal(2, stats.DisposedInScopeCount);
96-
}
9776
}
9877
}

test/TorchSharpTest/TestDisposeScopesStatisticsTensorUnscoped.cs

-18
Original file line numberDiff line numberDiff line change
@@ -114,23 +114,5 @@ public void AttachAgainIsSameAsMoveToOtherAndStatsDoNotChange()
114114
AssertTensorCounts(1, 0, 0, 0, 1, 0, 1);
115115
AssertTotalsCounts(1, 0, 0, 0, 1, 0, 1);
116116
}
117-
118-
[Fact]
119-
public void ToTensorCreatesOrphanedTensor()
120-
{
121-
//Defect: This needs fixing but is unrelated to the commit that discovered
122-
//it - adding better lifetime statistics. ToTensor() leaks 1 tensor
123-
// //every time it is called.
124-
var stats = DisposeScopeManager.Statistics;
125-
stats.Reset();
126-
var a1 = 1.ToTensor();
127-
//Should be 1 ideally. If it remains two...
128-
Assert.Equal(2, stats.CreatedOutsideScopeCount);
129-
// ... this should be 1
130-
Assert.Equal(0, stats.DisposedOutsideScopeCount);
131-
a1.Dispose();
132-
//... and this should also be 2
133-
Assert.Equal(1, stats.DisposedOutsideScopeCount);
134-
}
135117
}
136118
}

test/TorchSharpTest/TestTorchTensor.cs

+61
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Globalization;
55
using System.IO;
66
using System.Linq;
7+
using System.Numerics;
78
using System.Runtime.InteropServices;
89
using Xunit;
910
using Xunit.Sdk;
@@ -3228,6 +3229,66 @@ public void ScalarToTensor3()
32283229
Assert.Equal(-1.0, (double)tensor);
32293230
}
32303231
}
3232+
[Fact]
3233+
[TestOf(nameof(Tensor))]
3234+
public void ScalarToTensorDoesNotLeakMemory()
3235+
{
3236+
AssertTensorDoesNotLeak(()=>{
3237+
Tensor tensor = 1;
3238+
return tensor;
3239+
});
3240+
AssertTensorDoesNotLeak(() => ((byte)1).ToTensor());
3241+
AssertTensorDoesNotLeak(() => ((sbyte)-1).ToTensor());
3242+
AssertTensorDoesNotLeak(() => ((short)-1).ToTensor());
3243+
AssertTensorDoesNotLeak(() => ((long)-1).ToTensor());
3244+
AssertTensorDoesNotLeak(() => ((float)-1).ToTensor());
3245+
AssertTensorDoesNotLeak(() => ((double)-1).ToTensor());
3246+
}
3247+
3248+
[Fact]
3249+
[TestOf(nameof(Tensor))]
3250+
public void ScalarArrayToTensorDoesNotLeakMemory()
3251+
{
3252+
AssertTensorDoesNotLeak(() => (new byte[]{1}).ToTensor(new long[]{1}));
3253+
AssertTensorDoesNotLeak(() => (new sbyte[]{-1}).ToTensor(new long[]{1}));
3254+
AssertTensorDoesNotLeak(() => (new short[]{-1}).ToTensor(new long[]{1}));
3255+
AssertTensorDoesNotLeak(() => (new long[]{-1}).ToTensor(new long[]{1}));
3256+
AssertTensorDoesNotLeak(() => (new float[]{-1}).ToTensor(new long[]{1}));
3257+
AssertTensorDoesNotLeak(() => (new double[]{-1}).ToTensor(new long[]{1}));
3258+
}
3259+
3260+
[Fact]
3261+
[TestOf(nameof(Tensor))]
3262+
public void ComplexNumberOfDoubleDoesNotLeakMemory()
3263+
{
3264+
AssertTensorDoesNotLeak(() => ( torch.tensor((double)-1, (double)-2)));
3265+
AssertTensorDoesNotLeak(() => ( torch.tensor(((double)-1, (double)-2))));
3266+
}
3267+
3268+
[Fact]
3269+
[TestOf(nameof(Tensor))]
3270+
public void ComplexNumberOfFloatDoesNotLeakMemory()
3271+
{
3272+
AssertTensorDoesNotLeak(() => (torch.tensor((float)-1, (float)-2)));
3273+
AssertTensorDoesNotLeak(() => (torch.tensor(((float)-1, (float)-2))));
3274+
}
3275+
3276+
[Fact]
3277+
[TestOf(nameof(Tensor))]
3278+
public void DotNetComplexNumberDoesNotLeakMemory()
3279+
{
3280+
AssertTensorDoesNotLeak(() => (torch.tensor(new Complex(1, 2))));
3281+
}
3282+
3283+
private void AssertTensorDoesNotLeak(Func<Tensor> createTensorFunc)
3284+
{
3285+
var stats = DisposeScopeManager.Statistics.TensorStatistics;
3286+
stats.Reset();
3287+
using (Tensor tensor = createTensorFunc()) {
3288+
Assert.Equal(1, stats.ThreadTotalLiveCount);
3289+
}
3290+
Assert.Equal(0, stats.ThreadTotalLiveCount);
3291+
}
32313292

32323293
[Fact]
32333294
[TestOf(nameof(Tensor))]

0 commit comments

Comments
 (0)