Skip to content

Commit 53012fe

Browse files
committed
PackedSequence internal tensors manages internally to optimize scope management
1 parent 256c421 commit 53012fe

File tree

2 files changed

+11
-27
lines changed

2 files changed

+11
-27
lines changed

src/TorchSharp/NN/Utils/PackedSequence.cs

+4-20
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ protected override bool ReleaseHandle()
8686
internal PackedSequence(HType handle)
8787
{
8888
this.handle = handle;
89-
this.data = new Tensor(THSNN_PackedSequence_data(handle));
90-
this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle));
91-
this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle));
92-
this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle));
89+
this.data = new Tensor(THSNN_PackedSequence_data(handle)).DetachFromDisposeScope();
90+
this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle)).DetachFromDisposeScope();
91+
this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle)).DetachFromDisposeScope();
92+
this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle)).DetachFromDisposeScope();
9393
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this);
9494
}
9595

@@ -119,10 +119,6 @@ public void Dispose()
119119
/// <returns>The same PackedSequence that the method was called on</returns>
120120
public PackedSequence MoveToOuterDisposeScope()
121121
{
122-
OwningDisposeScope?.MoveToOuter(this.data);
123-
OwningDisposeScope?.MoveToOuter(this.batch_sizes);
124-
OwningDisposeScope?.MoveToOuter(this.sorted_indices);
125-
OwningDisposeScope?.MoveToOuter(this.unsorted_indices);
126122
OwningDisposeScope?.MoveToOuter(this);
127123
return this;
128124
}
@@ -133,10 +129,6 @@ public PackedSequence MoveToOuterDisposeScope()
133129
/// <returns>The same PackedSequence that the method was called on</returns>
134130
public PackedSequence DetachFromDisposeScope()
135131
{
136-
OwningDisposeScope?.Detach(this.data);
137-
OwningDisposeScope?.Detach(this.batch_sizes);
138-
OwningDisposeScope?.Detach(this.sorted_indices);
139-
OwningDisposeScope?.Detach(this.unsorted_indices);
140132
OwningDisposeScope?.Detach(this);
141133
return this;
142134
}
@@ -149,17 +141,9 @@ public PackedSequence MoveToOtherDisposeScope(PackedSequence other)
149141
public PackedSequence MoveToOtherDisposeScope(DisposeScope other)
150142
{
151143
if (OwningDisposeScope == null && other != null) {
152-
other.Attach(this.data);
153-
other.Attach(this.batch_sizes);
154-
other.Attach(this.sorted_indices);
155-
other.Attach(this.unsorted_indices);
156144
other.Attach(this);
157145
}
158146
else {
159-
OwningDisposeScope?.MoveToOther(other, this.data);
160-
OwningDisposeScope?.MoveToOther(other, this.batch_sizes);
161-
OwningDisposeScope?.MoveToOther(other, this.sorted_indices);
162-
OwningDisposeScope?.MoveToOther(other, this.unsorted_indices);
163147
OwningDisposeScope?.MoveToOther(other, this);
164148
}
165149
return this;

test/TorchSharpTest/TestDisposeScopesPackedSequence.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public void DisposablesValidityWhenNotSorted()
3939
var sequences = CreateTestSequences();
4040
using var scope = torch.NewDisposeScope();
4141
var packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false);
42-
Assert.Equal(5, scope.DisposablesCount);
42+
Assert.Equal(1, scope.DisposablesCount);
4343
AssertPackedSequenceValid(packed);
4444
}
4545

@@ -49,7 +49,7 @@ public void DisposablesValidityWhenSorted()
4949
var sequences = CreateTestSequences();
5050
using var scope = torch.NewDisposeScope();
5151
var packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
52-
Assert.Equal(5, scope.DisposablesCount);
52+
Assert.Equal(1, scope.DisposablesCount);
5353
Assert.False(GetPackedSequenceIsInvalid(packed));
5454
Assert.False(packed.batch_sizes.IsInvalid);
5555
Assert.False(packed.data.IsInvalid);
@@ -70,17 +70,17 @@ public void DisposeScopeStatistics()
7070
AssertStatCounts(0, 7, 0, 0, 0);
7171

7272
var inScope = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
73-
AssertStatCounts(5, 7, 0, 0, 5);
73+
AssertStatCounts(5, 7, 4, 0, 1);
7474

7575
scope.Attach(outOfScope);
76-
//Possible subtle bug. When attaching an object that isn't owned by any scope, the count goes negative.
77-
AssertStatCounts( 5, 7, -1, 0, 6);
76+
//Possible subtle bug. When attaching an object that isn't owned by any scope, the count subtracts.
77+
AssertStatCounts( 5, 7, 3, 0, 2);
7878

7979
scope.Detach(inScope);
80-
AssertStatCounts( 5, 7, 0, 0, 5);
80+
AssertStatCounts( 5, 7, 4, 0, 1);
8181

8282
outOfScope.Dispose();
83-
AssertStatCounts( 5, 7, 0, 5, 0);
83+
AssertStatCounts( 5, 7, 4, 5, -4);
8484

8585
}
8686

0 commit comments

Comments
 (0)