Skip to content

Commit 29522e9

Browse files
committed
Revert "Introduce IDisposeScopeClient"
This reverts commit 3aaefdd.
1 parent ade95b2 commit 29522e9

9 files changed

+179
-243
lines changed

src/TorchSharp/DataLoader.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ sealed class DataLoaderEnumerator : IEnumerator<S>
388388
{
389389
private readonly DataLoader<T, S> loader;
390390
private IEnumerator<long> shuffler;
391-
private HashSet<IDisposeScopeClient>? currentDisposables;
391+
private HashSet<IDisposable>? currentDisposables;
392392
public DataLoaderEnumerator(DataLoader<T, S> loader)
393393
{
394394
this.loader = loader;
@@ -397,7 +397,7 @@ public DataLoaderEnumerator(DataLoader<T, S> loader)
397397
Reset();
398398
}
399399

400-
private static void DisposeAll(HashSet<IDisposeScopeClient> disposables)
400+
private static void DisposeAll(HashSet<IDisposable> disposables)
401401
{
402402
foreach (var disposable in disposables) {
403403
disposable.Dispose();
@@ -426,7 +426,7 @@ public bool MoveNext()
426426
}
427427

428428
var tensors = new T[indices.Length];
429-
var getTensorDisposables = new HashSet<IDisposeScopeClient>[indices.Length];
429+
var getTensorDisposables = new HashSet<IDisposable>[indices.Length];
430430
Enumerable.Range(0, indices.Length)
431431
.AsParallel()
432432
.WithDegreeOfParallelism(loader.num_workers)

src/TorchSharp/DisposeScope.cs

+113-74
Large diffs are not rendered by default.

src/TorchSharp/DisposeScopeManager.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ public class DisposeScopeManager
1818
internal ThreadDisposeScopeStatistics StatisticsInstance { get; } = new ThreadDisposeScopeStatistics();
1919
internal DisposeScope? CurrentDisposeScope { get; private set; } = null;
2020

21-
internal DisposeScope? RegisterOnCurrentDisposeScope(IDisposeScopeClient client)
21+
internal DisposeScope? RegisterOnCurrentDisposeScope(IDisposable tensor)
2222
{
2323
if (this.CurrentDisposeScope is null) {
2424
StatisticsInstance.CreatedOutsideScopeCount++;
2525
return null;
2626
}
2727

2828
StatisticsInstance.CreatedInScopeCount++;
29-
this.CurrentDisposeScope.Disposables.Add(client);
29+
this.CurrentDisposeScope.Disposables.Add(tensor);
3030
return CurrentDisposeScope;
3131
}
3232

src/TorchSharp/IDisposeScopeClient.cs

-23
This file was deleted.

src/TorchSharp/NN/Utils/PackedSequence.cs

+20-27
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
using System.Runtime.InteropServices;
55
using static TorchSharp.PInvoke.NativeMethods;
66

7-
#nullable enable
87
namespace TorchSharp
98
{
109
public static partial class torch
@@ -18,9 +17,9 @@ public static partial class rnn
1817
/// <summary>
1918
/// A packed batch of variable length sequences.
2019
/// </summary>
21-
public sealed class PackedSequence : IDisposeScopeClient
20+
public sealed class PackedSequence : IDisposable
2221
{
23-
public DisposeScope? OwningDisposeScope { get; set; }
22+
internal DisposeScope OwningDisposeScope { get; set; }
2423

2524
/// <summary>
2625
/// Class wrapping PyTorch's packedsequence object reference.
@@ -80,7 +79,10 @@ internal PackedSequence(HType handle)
8079
this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle));
8180
this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle));
8281
this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle));
83-
82+
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.data);
83+
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.batch_sizes);
84+
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.sorted_indices);
85+
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.unsorted_indices);
8486
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this);
8587
}
8688

@@ -93,15 +95,14 @@ public void Dispose()
9395
{
9496
this.data.Dispose();
9597
this.batch_sizes.Dispose();
96-
if (!this.sorted_indices.IsInvalid) {
97-
this.sorted_indices.Dispose();
98-
this.unsorted_indices.Dispose();
99-
}
98+
this.sorted_indices.Dispose();
99+
this.unsorted_indices.Dispose();
100100
OwningDisposeScope?.MarkAsDisposed(this);
101101

102-
if (!handle.IsInvalid) {
102+
if (handle != null && !handle.IsInvalid) {
103103
handle.Dispose();
104104
handle.SetHandleAsInvalid();
105+
105106
}
106107
}
107108
/// <summary>
@@ -113,10 +114,8 @@ public PackedSequence MoveToOuterDisposeScope()
113114
{
114115
OwningDisposeScope?.MoveToOuter(this.data);
115116
OwningDisposeScope?.MoveToOuter(this.batch_sizes);
116-
if (!this.sorted_indices.IsInvalid) {
117-
OwningDisposeScope?.MoveToOuter(this.sorted_indices);
118-
OwningDisposeScope?.MoveToOuter(this.unsorted_indices);
119-
}
117+
OwningDisposeScope?.MoveToOuter(this.sorted_indices);
118+
OwningDisposeScope?.MoveToOuter(this.unsorted_indices);
120119
OwningDisposeScope?.MoveToOuter(this);
121120
return this;
122121
}
@@ -129,37 +128,31 @@ public PackedSequence DetachFromDisposeScope()
129128
{
130129
OwningDisposeScope?.Detach(this.data);
131130
OwningDisposeScope?.Detach(this.batch_sizes);
132-
if (!this.sorted_indices.IsInvalid) {
133-
OwningDisposeScope?.Detach(this.sorted_indices);
134-
OwningDisposeScope?.Detach(this.unsorted_indices);
135-
}
131+
OwningDisposeScope?.Detach(this.sorted_indices);
132+
OwningDisposeScope?.Detach(this.unsorted_indices);
136133
OwningDisposeScope?.Detach(this);
137134
return this;
138135
}
139136

140-
public PackedSequence MoveToOtherDisposeScope(IDisposeScopeClient other)
137+
public PackedSequence MoveToOtherDisposeScope(PackedSequence other)
141138
{
142139
return MoveToOtherDisposeScope(other.OwningDisposeScope);
143140
}
144141

145-
public PackedSequence MoveToOtherDisposeScope(DisposeScope? other)
142+
public PackedSequence MoveToOtherDisposeScope(DisposeScope other)
146143
{
147144
if (OwningDisposeScope == null && other != null) {
148145
other.Attach(this.data);
149146
other.Attach(this.batch_sizes);
150-
if (!this.sorted_indices.IsInvalid) {
151-
other.Attach(this.sorted_indices);
152-
other.Attach(this.unsorted_indices);
153-
}
147+
other.Attach(this.sorted_indices);
148+
other.Attach(this.unsorted_indices);
154149
other.Attach(this);
155150
}
156151
else {
157152
OwningDisposeScope?.MoveToOther(other, this.data);
158153
OwningDisposeScope?.MoveToOther(other, this.batch_sizes);
159-
if (!this.sorted_indices.IsInvalid) {
160-
OwningDisposeScope?.MoveToOther(other, this.sorted_indices);
161-
OwningDisposeScope?.MoveToOther(other, this.unsorted_indices);
162-
}
154+
OwningDisposeScope?.MoveToOther(other, this.sorted_indices);
155+
OwningDisposeScope?.MoveToOther(other, this.unsorted_indices);
163156
OwningDisposeScope?.MoveToOther(other, this);
164157
}
165158
return this;

src/TorchSharp/Tensor/Tensor.cs

+10-11
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public static partial class torch
2020
/// Represents a TorchSharp tensor.
2121
/// </summary>
2222
[TorchSharp.Utils.TypeFormatterSource(typeof(TorchSharp.Utils.TypeFormatterSource))]
23-
public partial class Tensor : IDisposeScopeClient
23+
public partial class Tensor : IDisposable
2424
{
2525
/// <summary>
2626
/// A handle to the underlying native tensor.
@@ -32,7 +32,7 @@ public partial class Tensor : IDisposeScopeClient
3232
static long _totalCount = 0;
3333
static long _peakCount = 0;
3434

35-
public DisposeScope? OwningDisposeScope { get; set; }
35+
internal DisposeScope? OwningDisposeScope { get; set; }
3636

3737
internal Tensor(IntPtr handle)
3838
{
@@ -417,7 +417,7 @@ public void WriteBytesToStream(Stream stream, int bufferSize = 1024)
417417
_validate(0);
418418

419419
long totalSize = NumberOfElements * ElementSize;
420-
420+
421421
unsafe {
422422
var ptr = NativeMethods.THSTensor_data(handle);
423423
if (ptr == IntPtr.Zero) { CheckForErrors(); }
@@ -451,7 +451,7 @@ public void ReadBytesFromStream(Stream stream, int bufferSize = 1024)
451451
long totalSize = NumberOfElements * ElementSize;
452452

453453
// Validate that this tensor matches the conditions for reading the bytes - pass 0 as total size
454-
// since we don't need to check that condition.
454+
// since we don't need to check that condition.
455455
_validate(0);
456456

457457
unsafe {
@@ -471,7 +471,7 @@ public void ReadBytesFromStream(Stream stream, int bufferSize = 1024)
471471
// Copy the contents over to the span
472472
var span = new Span<byte>((void*)ptr, bytesRead);
473473
buffer.AsSpan(0, bytesRead).CopyTo(span);
474-
474+
475475
// Increment our pointer and decrease the total size of elements we have to write
476476
ptr += bytesRead;
477477
totalSize -= bytesRead;
@@ -793,7 +793,7 @@ public Tensor mps(bool non_blocking = false)
793793
return new Tensor(res);
794794

795795
}
796-
796+
797797
/// <summary>
798798
/// Returns a copy of this object in CUDA memory.
799799
/// If this object is already in CUDA memory and on the correct device, then no copy is performed and the original object is returned.
@@ -3277,14 +3277,14 @@ public Tensor trace()
32773277
/// <summary>
32783278
/// Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by input.
32793279
/// To facilitate creating batched diagonal matrices, the 2D planes formed by the last two dimensions of the returned tensor are chosen by default.
3280-
///
3280+
///
32813281
/// The argument offset controls which diagonal to consider:
32823282
/// If offset is equal to 0, it is the main diagonal.
32833283
/// If offset is greater than 0, it is above the main diagonal.
32843284
/// If offset is less than 0, it is below the main diagonal.
3285-
///
3285+
///
32863286
/// The size of the new matrix will be calculated to make the specified diagonal of the size of the last input dimension.Note that for offset other than 0,
3287-
///
3287+
///
32883288
/// the order of dim1 and dim2 matters.Exchanging them is equivalent to changing the sign of offset.
32893289
/// </summary>
32903290
/// <param name="offset">Which diagonal to consider.</param>
@@ -3354,7 +3354,7 @@ public Tensor erf()
33543354
public Tensor erf_()
33553355
{
33563356
NativeMethods.THSTensor_erf_(Handle);
3357-
CheckForErrors();
3357+
CheckForErrors();
33583358
return this;
33593359
}
33603360

@@ -3438,7 +3438,6 @@ public Tensor eq_(Scalar target)
34383438

34393439
public bool Equals(Tensor target)
34403440
{
3441-
var t = this;
34423441
if (target is null) return false;
34433442
var res = NativeMethods.THSTensor_equal(Handle, target.Handle);
34443443
CheckForErrors();

test/TorchSharpTest/TestDisposeScopesWithPackedSequence.cs

-69
This file was deleted.

test/TorchSharpTest/TestDisposeScopesWithUserObject.cs

-34
This file was deleted.

0 commit comments

Comments
 (0)