Skip to content

Commit 3aaefdd

Browse files
committed
Introduce IDisposeScopeClient
1 parent b545520 commit 3aaefdd

File tree

9 files changed

+243
-179
lines changed

9 files changed

+243
-179
lines changed

src/TorchSharp/DataLoader.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ sealed class DataLoaderEnumerator : IEnumerator<S>
388388
{
389389
private readonly DataLoader<T, S> loader;
390390
private IEnumerator<long> shuffler;
391-
private HashSet<IDisposable>? currentDisposables;
391+
private HashSet<IDisposeScopeClient>? currentDisposables;
392392
public DataLoaderEnumerator(DataLoader<T, S> loader)
393393
{
394394
this.loader = loader;
@@ -397,7 +397,7 @@ public DataLoaderEnumerator(DataLoader<T, S> loader)
397397
Reset();
398398
}
399399

400-
private static void DisposeAll(HashSet<IDisposable> disposables)
400+
private static void DisposeAll(HashSet<IDisposeScopeClient> disposables)
401401
{
402402
foreach (var disposable in disposables) {
403403
disposable.Dispose();
@@ -426,7 +426,7 @@ public bool MoveNext()
426426
}
427427

428428
var tensors = new T[indices.Length];
429-
var getTensorDisposables = new HashSet<IDisposable>[indices.Length];
429+
var getTensorDisposables = new HashSet<IDisposeScopeClient>[indices.Length];
430430
Enumerable.Range(0, indices.Length)
431431
.AsParallel()
432432
.WithDegreeOfParallelism(loader.num_workers)

src/TorchSharp/DisposeScope.cs

Lines changed: 74 additions & 113 deletions
Large diffs are not rendered by default.

src/TorchSharp/DisposeScopeManager.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ public class DisposeScopeManager
1818
internal ThreadDisposeScopeStatistics StatisticsInstance { get; } = new ThreadDisposeScopeStatistics();
1919
internal DisposeScope? CurrentDisposeScope { get; private set; } = null;
2020

21-
internal DisposeScope? RegisterOnCurrentDisposeScope(IDisposable tensor)
21+
internal DisposeScope? RegisterOnCurrentDisposeScope(IDisposeScopeClient client)
2222
{
2323
if (this.CurrentDisposeScope is null) {
2424
StatisticsInstance.CreatedOutsideScopeCount++;
2525
return null;
2626
}
2727

2828
StatisticsInstance.CreatedInScopeCount++;
29-
this.CurrentDisposeScope.Disposables.Add(tensor);
29+
this.CurrentDisposeScope.Disposables.Add(client);
3030
return CurrentDisposeScope;
3131
}
3232

src/TorchSharp/IDisposeScopeClient.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#nullable enable
2+
using System;
3+
4+
namespace TorchSharp
5+
{
6+
/// <summary>
7+
/// Represents any object managed by the DisposeScope system. You must invoke Attach on DisposeScope manually
8+
/// to begin participating. Recommended implementation is to pass a DisposeScope to your object's
9+
/// constructor and invoke scope.Attach during creation.
10+
/// </summary>
11+
public interface IDisposeScopeClient: IDisposable
12+
{
13+
/// <summary>
14+
/// The DisposeScope that currently owns this object. Do not modify this property
15+
/// directly, it is managed by the scope system
16+
/// </summary>
17+
public DisposeScope? OwningDisposeScope { get; set; }
18+
/// <summary>
19+
/// Is true if the object has been disposed, false otherwise.
20+
/// </summary>
21+
public bool IsInvalid { get; }
22+
}
23+
}

src/TorchSharp/NN/Utils/PackedSequence.cs

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Runtime.InteropServices;
55
using static TorchSharp.PInvoke.NativeMethods;
66

7+
#nullable enable
78
namespace TorchSharp
89
{
910
public static partial class torch
@@ -17,9 +18,9 @@ public static partial class rnn
1718
/// <summary>
1819
/// A packed batch of variable length sequences.
1920
/// </summary>
20-
public sealed class PackedSequence : IDisposable
21+
public sealed class PackedSequence : IDisposeScopeClient
2122
{
22-
internal DisposeScope OwningDisposeScope { get; set; }
23+
public DisposeScope? OwningDisposeScope { get; set; }
2324

2425
/// <summary>
2526
/// Class wrapping PyTorch's packedsequence object reference.
@@ -79,10 +80,7 @@ internal PackedSequence(HType handle)
7980
this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle));
8081
this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle));
8182
this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle));
82-
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.data);
83-
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.batch_sizes);
84-
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.sorted_indices);
85-
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.unsorted_indices);
83+
8684
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this);
8785
}
8886

@@ -95,14 +93,15 @@ public void Dispose()
9593
{
9694
this.data.Dispose();
9795
this.batch_sizes.Dispose();
98-
this.sorted_indices.Dispose();
99-
this.unsorted_indices.Dispose();
96+
if (!this.sorted_indices.IsInvalid) {
97+
this.sorted_indices.Dispose();
98+
this.unsorted_indices.Dispose();
99+
}
100100
OwningDisposeScope?.MarkAsDisposed(this);
101101

102-
if (handle != null && !handle.IsInvalid) {
102+
if (!handle.IsInvalid) {
103103
handle.Dispose();
104104
handle.SetHandleAsInvalid();
105-
106105
}
107106
}
108107
/// <summary>
@@ -114,8 +113,10 @@ public PackedSequence MoveToOuterDisposeScope()
114113
{
115114
OwningDisposeScope?.MoveToOuter(this.data);
116115
OwningDisposeScope?.MoveToOuter(this.batch_sizes);
117-
OwningDisposeScope?.MoveToOuter(this.sorted_indices);
118-
OwningDisposeScope?.MoveToOuter(this.unsorted_indices);
116+
if (!this.sorted_indices.IsInvalid) {
117+
OwningDisposeScope?.MoveToOuter(this.sorted_indices);
118+
OwningDisposeScope?.MoveToOuter(this.unsorted_indices);
119+
}
119120
OwningDisposeScope?.MoveToOuter(this);
120121
return this;
121122
}
@@ -128,31 +129,37 @@ public PackedSequence DetachFromDisposeScope()
128129
{
129130
OwningDisposeScope?.Detach(this.data);
130131
OwningDisposeScope?.Detach(this.batch_sizes);
131-
OwningDisposeScope?.Detach(this.sorted_indices);
132-
OwningDisposeScope?.Detach(this.unsorted_indices);
132+
if (!this.sorted_indices.IsInvalid) {
133+
OwningDisposeScope?.Detach(this.sorted_indices);
134+
OwningDisposeScope?.Detach(this.unsorted_indices);
135+
}
133136
OwningDisposeScope?.Detach(this);
134137
return this;
135138
}
136139

137-
public PackedSequence MoveToOtherDisposeScope(PackedSequence other)
140+
public PackedSequence MoveToOtherDisposeScope(IDisposeScopeClient other)
138141
{
139142
return MoveToOtherDisposeScope(other.OwningDisposeScope);
140143
}
141144

142-
public PackedSequence MoveToOtherDisposeScope(DisposeScope other)
145+
public PackedSequence MoveToOtherDisposeScope(DisposeScope? other)
143146
{
144147
if (OwningDisposeScope == null && other != null) {
145148
other.Attach(this.data);
146149
other.Attach(this.batch_sizes);
147-
other.Attach(this.sorted_indices);
148-
other.Attach(this.unsorted_indices);
150+
if (!this.sorted_indices.IsInvalid) {
151+
other.Attach(this.sorted_indices);
152+
other.Attach(this.unsorted_indices);
153+
}
149154
other.Attach(this);
150155
}
151156
else {
152157
OwningDisposeScope?.MoveToOther(other, this.data);
153158
OwningDisposeScope?.MoveToOther(other, this.batch_sizes);
154-
OwningDisposeScope?.MoveToOther(other, this.sorted_indices);
155-
OwningDisposeScope?.MoveToOther(other, this.unsorted_indices);
159+
if (!this.sorted_indices.IsInvalid) {
160+
OwningDisposeScope?.MoveToOther(other, this.sorted_indices);
161+
OwningDisposeScope?.MoveToOther(other, this.unsorted_indices);
162+
}
156163
OwningDisposeScope?.MoveToOther(other, this);
157164
}
158165
return this;

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public static partial class torch
2020
/// Represents a TorchSharp tensor.
2121
/// </summary>
2222
[TorchSharp.Utils.TypeFormatterSource(typeof(TorchSharp.Utils.TypeFormatterSource))]
23-
public partial class Tensor : IDisposable
23+
public partial class Tensor : IDisposeScopeClient
2424
{
2525
/// <summary>
2626
/// A handle to the underlying native tensor.
@@ -32,7 +32,7 @@ public partial class Tensor : IDisposable
3232
static long _totalCount = 0;
3333
static long _peakCount = 0;
3434

35-
internal DisposeScope? OwningDisposeScope { get; set; }
35+
public DisposeScope? OwningDisposeScope { get; set; }
3636

3737
internal Tensor(IntPtr handle)
3838
{
@@ -417,7 +417,7 @@ public void WriteBytesToStream(Stream stream, int bufferSize = 1024)
417417
_validate(0);
418418

419419
long totalSize = NumberOfElements * ElementSize;
420-
420+
421421
unsafe {
422422
var ptr = NativeMethods.THSTensor_data(handle);
423423
if (ptr == IntPtr.Zero) { CheckForErrors(); }
@@ -451,7 +451,7 @@ public void ReadBytesFromStream(Stream stream, int bufferSize = 1024)
451451
long totalSize = NumberOfElements * ElementSize;
452452

453453
// Validate that this tensor matches the conditions for reading the bytes - pass 0 as total size
454-
// since we don't need to check that condition.
454+
// since we don't need to check that condition.
455455
_validate(0);
456456

457457
unsafe {
@@ -471,7 +471,7 @@ public void ReadBytesFromStream(Stream stream, int bufferSize = 1024)
471471
// Copy the contents over to the span
472472
var span = new Span<byte>((void*)ptr, bytesRead);
473473
buffer.AsSpan(0, bytesRead).CopyTo(span);
474-
474+
475475
// Increment our pointer and decrease the total size of elements we have to write
476476
ptr += bytesRead;
477477
totalSize -= bytesRead;
@@ -793,7 +793,7 @@ public Tensor mps(bool non_blocking = false)
793793
return new Tensor(res);
794794

795795
}
796-
796+
797797
/// <summary>
798798
/// Returns a copy of this object in CUDA memory.
799799
/// If this object is already in CUDA memory and on the correct device, then no copy is performed and the original object is returned.
@@ -3277,14 +3277,14 @@ public Tensor trace()
32773277
/// <summary>
32783278
/// Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by input.
32793279
/// To facilitate creating batched diagonal matrices, the 2D planes formed by the last two dimensions of the returned tensor are chosen by default.
3280-
///
3280+
///
32813281
/// The argument offset controls which diagonal to consider:
32823282
/// If offset is equal to 0, it is the main diagonal.
32833283
/// If offset is greater than 0, it is above the main diagonal.
32843284
/// If offset is less than 0, it is below the main diagonal.
3285-
///
3285+
///
32863286
/// The size of the new matrix will be calculated to make the specified diagonal of the size of the last input dimension.Note that for offset other than 0,
3287-
///
3287+
///
32883288
/// the order of dim1 and dim2 matters.Exchanging them is equivalent to changing the sign of offset.
32893289
/// </summary>
32903290
/// <param name="offset">Which diagonal to consider.</param>
@@ -3354,7 +3354,7 @@ public Tensor erf()
33543354
public Tensor erf_()
33553355
{
33563356
NativeMethods.THSTensor_erf_(Handle);
3357-
CheckForErrors();
3357+
CheckForErrors();
33583358
return this;
33593359
}
33603360

@@ -3438,6 +3438,7 @@ public Tensor eq_(Scalar target)
34383438

34393439
public bool Equals(Tensor target)
34403440
{
3441+
var t = this;
34413442
if (target is null) return false;
34423443
var res = NativeMethods.THSTensor_equal(Handle, target.Handle);
34433444
CheckForErrors();
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
using TorchSharp;
2+
using Xunit;
3+
4+
namespace TorchSharpTest;
5+
6+
[Collection("Sequential")]
7+
public class TestDisposeScopesWithPackedSequence
8+
{
9+
[Fact]
10+
public void PackSequencesMoveDisposeScope()
11+
{
12+
torch.nn.utils.rnn.PackedSequence packed_sequence;
13+
var otherScope = torch.NewDisposeScope();
14+
using (var outerScope = torch.NewDisposeScope()) {
15+
using (var innerScope = torch.NewDisposeScope()) {
16+
var sequences = make_sequence_tensors();
17+
packed_sequence = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false);
18+
AssertPackedSequenceValid(packed_sequence);
19+
packed_sequence.MoveToOuterDisposeScope();
20+
}
21+
22+
AssertPackedSequenceValid(packed_sequence);
23+
packed_sequence.MoveToOtherDisposeScope(otherScope);
24+
}
25+
26+
AssertPackedSequenceValid(packed_sequence);
27+
otherScope.Dispose();
28+
Assert.True(packed_sequence.IsInvalid);
29+
Assert.True(packed_sequence.data.IsInvalid);
30+
}
31+
32+
[Fact]
33+
public void PackedSequencesWorkWhenSorted()
34+
{
35+
var sequences = make_sequence_tensors();
36+
37+
var scope = torch.NewDisposeScope();
38+
var packed_sequence = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
39+
Assert.Equal(5, scope.DisposablesCount);
40+
Assert.False(packed_sequence.IsInvalid);
41+
Assert.False(packed_sequence.batch_sizes.IsInvalid);
42+
Assert.False(packed_sequence.data.IsInvalid);
43+
Assert.True(packed_sequence.sorted_indices.IsInvalid);
44+
Assert.True(packed_sequence.unsorted_indices.IsInvalid);
45+
46+
scope.Dispose();
47+
Assert.True(packed_sequence.IsInvalid);
48+
Assert.True(packed_sequence.batch_sizes.IsInvalid);
49+
Assert.True(packed_sequence.data.IsInvalid);
50+
Assert.True(packed_sequence.sorted_indices.IsInvalid);
51+
Assert.True(packed_sequence.unsorted_indices.IsInvalid);
52+
}
53+
54+
private static void AssertPackedSequenceValid(torch.nn.utils.rnn.PackedSequence packed_sequence)
55+
{
56+
Assert.False(packed_sequence.IsInvalid);
57+
Assert.False(packed_sequence.batch_sizes.IsInvalid);
58+
Assert.False(packed_sequence.data.IsInvalid);
59+
Assert.False(packed_sequence.sorted_indices.IsInvalid);
60+
Assert.False(packed_sequence.unsorted_indices.IsInvalid);
61+
}
62+
63+
private static torch.Tensor[] make_sequence_tensors()
64+
{
65+
var sequences =
66+
new torch.Tensor[] { torch.tensor(new long[] { 1, 2, 3, 4 }), torch.tensor(new long[] { 5, 6 }), };
67+
return sequences;
68+
}
69+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using TorchSharp;
2+
using Xunit;
3+
4+
namespace TorchSharpTest;
5+
6+
[Collection("Sequential")]
7+
public class TestDisposeScopesWithUserObject
8+
{
9+
[Fact]
10+
public void UserObjectCanParticipateInScopeSystem()
11+
{
12+
var scope = torch.NewDisposeScope();
13+
var custom = new CustomScopedObject();
14+
Assert.False(custom.IsInvalid);
15+
Assert.Equal(0, scope.DisposablesCount);
16+
17+
scope.Attach(custom);
18+
Assert.Equal(1, scope.DisposablesCount);
19+
20+
scope.Dispose();
21+
Assert.True(custom.IsInvalid);
22+
}
23+
24+
private class CustomScopedObject : IDisposeScopeClient
25+
{
26+
public void Dispose()
27+
{
28+
IsInvalid = true;
29+
}
30+
31+
public DisposeScope OwningDisposeScope { get; set; }
32+
public bool IsInvalid { get; private set; } = false;
33+
}
34+
}

0 commit comments

Comments
 (0)