1+ using System . Reflection ;
2+ using TorchSharp ;
3+ using Xunit ;
4+
5+ namespace TorchSharpTest ;
6+
7+ public class TestDisposeScopesPackedSequence
8+ {
9+ [ Fact ]
10+ public void MoveDisposeScope ( )
11+ {
12+ var sequences = CreateTestSequences ( ) ;
13+ torch . nn . utils . rnn . PackedSequence packed_sequence ;
14+ var otherScope = torch . NewDisposeScope ( ) ;
15+ using ( torch . NewDisposeScope ( ) )
16+ {
17+ using ( torch . NewDisposeScope ( ) )
18+ {
19+ packed_sequence = torch . nn . utils . rnn . pack_sequence ( sequences , enforce_sorted : false ) ;
20+ AssertPackedSequenceValid ( packed_sequence ) ;
21+
22+ packed_sequence . MoveToOuterDisposeScope ( ) ;
23+ }
24+ AssertPackedSequenceValid ( packed_sequence ) ;
25+
26+ packed_sequence . MoveToOtherDisposeScope ( otherScope ) ;
27+ }
28+
29+ AssertPackedSequenceValid ( packed_sequence ) ;
30+ otherScope . Dispose ( ) ;
31+
32+ Assert . True ( GetPackedSequenceIsInvalid ( packed_sequence ) ) ;
33+ Assert . True ( packed_sequence . data . IsInvalid ) ;
34+ }
35+
36+ [ Fact ]
37+ public void DisposablesValidityWhenNotSorted ( )
38+ {
39+ var sequences = CreateTestSequences ( ) ;
40+ using var scope = torch . NewDisposeScope ( ) ;
41+ var packed = torch . nn . utils . rnn . pack_sequence ( sequences , enforce_sorted : false ) ;
42+ Assert . Equal ( 5 , scope . DisposablesCount ) ;
43+ AssertPackedSequenceValid ( packed ) ;
44+ }
45+
46+ [ Fact ]
47+ public void DisposablesValidityWhenSorted ( )
48+ {
49+ var sequences = CreateTestSequences ( ) ;
50+ using var scope = torch . NewDisposeScope ( ) ;
51+ var packed = torch . nn . utils . rnn . pack_sequence ( sequences , enforce_sorted : true ) ;
52+ Assert . Equal ( 5 , scope . DisposablesCount ) ;
53+ Assert . False ( GetPackedSequenceIsInvalid ( packed ) ) ;
54+ Assert . False ( packed . batch_sizes . IsInvalid ) ;
55+ Assert . False ( packed . data . IsInvalid ) ;
56+ Assert . True ( packed . sorted_indices . IsInvalid ) ;
57+ Assert . True ( packed . unsorted_indices . IsInvalid ) ;
58+ }
59+
60+ [ Fact ]
61+ public void DisposeScopeStatistics ( )
62+ {
63+ DisposeScopeManager . Statistics . Reset ( ) ;
64+ AssertStatCounts ( 0 , 0 , 0 , 0 , 0 ) ;
65+ var sequences = CreateTestSequences ( ) ;
66+ AssertStatCounts ( 0 , 2 , 0 , 0 , 0 ) ;
67+ var outOfScope = torch . nn . utils . rnn . pack_sequence ( sequences , enforce_sorted : true ) ;
68+ AssertStatCounts ( 0 , 7 , 0 , 0 , 0 ) ;
69+ using var scope = torch . NewDisposeScope ( ) ;
70+ AssertStatCounts ( 0 , 7 , 0 , 0 , 0 ) ;
71+
72+ var inScope = torch . nn . utils . rnn . pack_sequence ( sequences , enforce_sorted : true ) ;
73+ AssertStatCounts ( 5 , 7 , 0 , 0 , 5 ) ;
74+
75+ scope . Attach ( outOfScope ) ;
76+ //Possible subtle bug. When attaching an object that isn't owned by any scope, the count goes negative.
77+ AssertStatCounts ( 5 , 7 , - 1 , 0 , 6 ) ;
78+
79+ scope . Detach ( inScope ) ;
80+ AssertStatCounts ( 5 , 7 , 0 , 0 , 5 ) ;
81+
82+ outOfScope . Dispose ( ) ;
83+ AssertStatCounts ( 5 , 7 , 0 , 5 , 0 ) ;
84+
85+ }
86+
87+ private static void AssertStatCounts ( long createdInScope , long createdOutsideScope , long detachedFrom , long disposedIn , long threadTotalLive )
88+ {
89+ var stats = DisposeScopeManager . Statistics ;
90+ Assert . Equal ( createdInScope , stats . CreatedInScopeCount ) ;
91+ Assert . Equal ( createdOutsideScope , stats . CreatedOutsideScopeCount ) ;
92+ Assert . Equal ( detachedFrom , stats . DetachedFromScopeCount ) ;
93+ Assert . Equal ( disposedIn , stats . DisposedInScopeCount ) ;
94+ Assert . Equal ( threadTotalLive , stats . ThreadTotalLiveCount ) ;
95+ }
96+
97+ private static torch . Tensor [ ] CreateTestSequences ( )
98+ {
99+ return new [ ]
100+ {
101+ torch . tensor ( new long [ ] { 1 , 2 , 3 , 4 } ) ,
102+ torch . tensor ( new long [ ] { 5 , 6 } ) ,
103+ } ;
104+ }
105+
106+ private static void AssertPackedSequenceValid ( torch . nn . utils . rnn . PackedSequence packed_sequence )
107+ {
108+ Assert . False ( GetPackedSequenceIsInvalid ( packed_sequence ) ) ;
109+ Assert . False ( packed_sequence . batch_sizes . IsInvalid ) ;
110+ Assert . False ( packed_sequence . data . IsInvalid ) ;
111+ Assert . False ( packed_sequence . sorted_indices . IsInvalid ) ;
112+ Assert . False ( packed_sequence . unsorted_indices . IsInvalid ) ;
113+ }
114+
115+ private static bool GetPackedSequenceIsInvalid ( torch . nn . utils . rnn . PackedSequence packed_sequence )
116+ {
117+ //HACK: reflection to avoid exposing internal method IsInvalid in API.
118+ var getter = typeof ( torch . nn . utils . rnn . PackedSequence ) . GetProperty ( "IsInvalid" , BindingFlags . Instance | BindingFlags . NonPublic ) ! ;
119+ return ( bool ) getter . GetValue ( packed_sequence ) ! ;
120+ }
121+ }
0 commit comments