Skip to content

Commit ee455bb

Browse files
committed
PackedSequence handling identical to tensor
1 parent 29522e9 commit ee455bb

File tree

4 files changed

+135
-38
lines changed

4 files changed

+135
-38
lines changed

src/TorchSharp/DisposeScope.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ public IReadOnlyList<IDisposable> Attach(IEnumerable<IDisposable> disposables)
243243
}
244244
}
245245
else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
246-
if (sequence.OwningDisposeScope == null) {
246+
if (sequence.OwningDisposeScope == null && !sequence.IsInvalid) {
247247
_disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
248248
}
249249
}

src/TorchSharp/NN/Utils/PackedSequence.cs

+13-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
22

33
using System;
4+
using System.Runtime.CompilerServices;
45
using System.Runtime.InteropServices;
56
using static TorchSharp.PInvoke.NativeMethods;
67

@@ -19,7 +20,16 @@ public static partial class rnn
1920
/// </summary>
2021
public sealed class PackedSequence : IDisposable
2122
{
22-
internal DisposeScope OwningDisposeScope { get; set; }
23+
internal DisposeScope OwningDisposeScope {
24+
get => mOwningDisposeScope;
25+
set {
26+
mOwningDisposeScope = value;
27+
this.batch_sizes.OwningDisposeScope = value;
28+
this.data.OwningDisposeScope = value;
29+
this.sorted_indices.OwningDisposeScope = value;
30+
this.unsorted_indices.OwningDisposeScope = value;
31+
}
32+
}
2333

2434
/// <summary>
2535
/// Class wrapping PyTorch's packedsequence object reference.
@@ -69,8 +79,9 @@ protected override bool ReleaseHandle()
6979
/// <summary>
7080
/// Is true if the PackedSequence has been disposed, false otherwise.
7181
/// </summary>
72-
public bool IsInvalid => handle.IsInvalid;
82+
internal bool IsInvalid => handle.IsInvalid;
7383
private HType handle;
84+
private DisposeScope mOwningDisposeScope;
7485

7586
internal PackedSequence(HType handle)
7687
{
@@ -79,10 +90,6 @@ internal PackedSequence(HType handle)
7990
this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle));
8091
this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle));
8192
this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle));
82-
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.data);
83-
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.batch_sizes);
84-
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.sorted_indices);
85-
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this.unsorted_indices);
8693
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this);
8794
}
8895

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
using System.Reflection;
2+
using TorchSharp;
3+
using Xunit;
4+
5+
namespace TorchSharpTest;
6+
7+
public class TestDisposeScopesPackedSequence
8+
{
9+
[Fact]
10+
public void MoveDisposeScope()
11+
{
12+
var sequences = CreateTestSequences();
13+
torch.nn.utils.rnn.PackedSequence packed_sequence;
14+
var otherScope = torch.NewDisposeScope();
15+
using (torch.NewDisposeScope())
16+
{
17+
using (torch.NewDisposeScope())
18+
{
19+
packed_sequence = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false);
20+
AssertPackedSequenceValid(packed_sequence);
21+
22+
packed_sequence.MoveToOuterDisposeScope();
23+
}
24+
AssertPackedSequenceValid(packed_sequence);
25+
26+
packed_sequence.MoveToOtherDisposeScope(otherScope);
27+
}
28+
29+
AssertPackedSequenceValid(packed_sequence);
30+
otherScope.Dispose();
31+
32+
Assert.True(GetPackedSequenceIsInvalid(packed_sequence));
33+
Assert.True(packed_sequence.data.IsInvalid);
34+
}
35+
36+
[Fact]
37+
public void DisposablesValidityWhenNotSorted()
38+
{
39+
var sequences = CreateTestSequences();
40+
using var scope = torch.NewDisposeScope();
41+
var packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false);
42+
Assert.Equal(5, scope.DisposablesCount);
43+
AssertPackedSequenceValid(packed);
44+
}
45+
46+
[Fact]
47+
public void DisposablesValidityWhenSorted()
48+
{
49+
var sequences = CreateTestSequences();
50+
using var scope = torch.NewDisposeScope();
51+
var packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
52+
Assert.Equal(5, scope.DisposablesCount);
53+
Assert.False(GetPackedSequenceIsInvalid(packed));
54+
Assert.False(packed.batch_sizes.IsInvalid);
55+
Assert.False(packed.data.IsInvalid);
56+
Assert.True(packed.sorted_indices.IsInvalid);
57+
Assert.True(packed.unsorted_indices.IsInvalid);
58+
}
59+
60+
[Fact]
61+
public void DisposeScopeStatistics()
62+
{
63+
DisposeScopeManager.Statistics.Reset();
64+
AssertStatCounts(0, 0, 0, 0, 0);
65+
var sequences = CreateTestSequences();
66+
AssertStatCounts(0, 2, 0, 0, 0);
67+
var outOfScope = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
68+
AssertStatCounts(0, 7, 0, 0, 0);
69+
using var scope = torch.NewDisposeScope();
70+
AssertStatCounts(0, 7, 0, 0, 0);
71+
72+
var inScope = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
73+
AssertStatCounts(5, 7, 0, 0, 5);
74+
75+
scope.Attach(outOfScope);
76+
//Possible subtle bug. When attaching an object that isn't owned by any scope, the count goes negative.
77+
AssertStatCounts( 5, 7, -1, 0, 6);
78+
79+
scope.Detach(inScope);
80+
AssertStatCounts( 5, 7, 0, 0, 5);
81+
82+
outOfScope.Dispose();
83+
AssertStatCounts( 5, 7, 0, 5, 0);
84+
85+
}
86+
87+
private static void AssertStatCounts(long createdInScope, long createdOutsideScope, long detachedFrom, long disposedIn, long threadTotalLive)
88+
{
89+
var stats = DisposeScopeManager.Statistics;
90+
Assert.Equal(createdInScope, stats.CreatedInScopeCount);
91+
Assert.Equal(createdOutsideScope, stats.CreatedOutsideScopeCount);
92+
Assert.Equal(detachedFrom, stats.DetachedFromScopeCount);
93+
Assert.Equal(disposedIn, stats.DisposedInScopeCount);
94+
Assert.Equal(threadTotalLive, stats.ThreadTotalLiveCount);
95+
}
96+
97+
private static torch.Tensor[] CreateTestSequences()
98+
{
99+
return new[]
100+
{
101+
torch.tensor(new long[] { 1, 2, 3, 4 }),
102+
torch.tensor(new long[] { 5, 6 }),
103+
};
104+
}
105+
106+
private static void AssertPackedSequenceValid(torch.nn.utils.rnn.PackedSequence packed_sequence)
107+
{
108+
Assert.False(GetPackedSequenceIsInvalid(packed_sequence));
109+
Assert.False(packed_sequence.batch_sizes.IsInvalid);
110+
Assert.False(packed_sequence.data.IsInvalid);
111+
Assert.False(packed_sequence.sorted_indices.IsInvalid);
112+
Assert.False(packed_sequence.unsorted_indices.IsInvalid);
113+
}
114+
115+
private static bool GetPackedSequenceIsInvalid(torch.nn.utils.rnn.PackedSequence packed_sequence)
116+
{
117+
//HACK: reflection to avoid exposing internal method IsInvalid in API.
118+
var getter = typeof(torch.nn.utils.rnn.PackedSequence).GetProperty("IsInvalid", BindingFlags.Instance | BindingFlags.NonPublic)!;
119+
return (bool)getter.GetValue(packed_sequence)!;
120+
}
121+
}

test/TorchSharpTest/TestNNUtils.cs

-31
Original file line numberDiff line numberDiff line change
@@ -54,37 +54,6 @@ public void TestPackSequence()
5454
Assert.True(torch.max(torch.square(inverted_sequences - padded_sequences)).item<long>() == 0);
5555
}
5656

57-
[Fact]
58-
public void TestPackSequenceMoveDisposeScope()
59-
{
60-
nn.utils.rnn.PackedSequence packed_sequence;
61-
var otherScope = NewDisposeScope();
62-
using (var outerScope = NewDisposeScope())
63-
{
64-
using (var innerScope = NewDisposeScope()) {
65-
var (sequences, sequences_len) = make_test();
66-
packed_sequence = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false);
67-
AssertPackedSequenceValid(packed_sequence);
68-
packed_sequence.MoveToOuterDisposeScope();
69-
}
70-
AssertPackedSequenceValid(packed_sequence);
71-
packed_sequence.MoveToOtherDisposeScope(otherScope);
72-
}
73-
AssertPackedSequenceValid(packed_sequence);
74-
otherScope.Dispose();
75-
Assert.True(packed_sequence.IsInvalid);
76-
Assert.True(packed_sequence.data.IsInvalid);
77-
}
78-
79-
private static void AssertPackedSequenceValid(nn.utils.rnn.PackedSequence packed_sequence)
80-
{
81-
Assert.False(packed_sequence.IsInvalid);
82-
Assert.False(packed_sequence.batch_sizes.IsInvalid);
83-
Assert.False(packed_sequence.data.IsInvalid);
84-
Assert.False(packed_sequence.sorted_indices.IsInvalid);
85-
Assert.False(packed_sequence.unsorted_indices.IsInvalid);
86-
}
87-
8857
[Fact]
8958
public void TestAutoGradGrad()
9059
{

0 commit comments

Comments
 (0)