Skip to content

Commit 5522fca

Browse files
authored
Move input array checks out of autograd example (#231)
* Add missing summary of autograd to examples README * Move cleanup subroutine and checks that input arrays aren't changed by arithmetic * Call cleanup if assertion fails in operator overload tests * Call cleanup if assertion fails in constructor tests
1 parent d139651 commit 5522fca

File tree

4 files changed

+396
-145
lines changed

4 files changed

+396
-145
lines changed

examples/6_Autograd/autograd.f90

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,28 +67,9 @@ program example
6767
stop 999
6868
end if
6969

70-
! Check first input array is unchanged by the arithmetic operations
71-
expected(:,1) = [2.0_wp, 3.0_wp]
72-
test_pass = assert_allclose(in_data1, expected, test_name="torch_tensor_to_array", rtol=1e-5)
73-
if (.not. test_pass) then
74-
call clean_up()
75-
print *, "Error :: in_data1 was changed during arithmetic operations"
76-
stop 999
77-
end if
78-
79-
! Check second input array is unchanged by the arithmetic operations
80-
expected(:,1) = [6.0_wp, 4.0_wp]
81-
test_pass = assert_allclose(in_data2, expected, test_name="torch_tensor_to_array", rtol=1e-5)
82-
if (.not. test_pass) then
83-
call clean_up()
84-
print *, "Error :: in_data2 was changed during arithmetic operations"
85-
stop 999
86-
end if
87-
8870
! Back-propagation
8971
! TODO: Requires API extension
9072

91-
! Cleanup
9273
call clean_up()
9374
write (*,*) "Autograd example ran successfully"
9475

examples/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ This directory contains a number of examples of how to use the library:
2323
This example demonstrates how to structure code in these cases, separating reading
2424
in of a net from the call to the forward pass.
2525

26+
6. Autograd
27+
- **This example is currently under development.** Eventually, it will
28+
demonstrate automatic differentation in FTorch by leveraging PyTorch's
29+
Autograd module.
30+
2631
To run select examples as integration tests, use the CMake argument
2732
```
2833
-DCMAKE_BUILD_TESTS=TRUE

src/test/unit/test_constructors.pf

Lines changed: 84 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,21 @@ subroutine test_torch_tensor_zeros()
8181
! Check that the tensor values are all zero
8282
expected(:,:) = 0.0
8383
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_zeros")
84-
@assertTrue(test_pass)
85-
@assertEqual(shape(out_data), shape(expected))
84+
if (.not. test_pass) then
85+
call clean_up()
86+
print *, "Error :: incorrect output from torch_tensor_zeros subroutine"
87+
stop 999
88+
end if
8689

87-
! Cleanup
88-
nullify(out_data)
89-
call torch_tensor_delete(tensor)
90+
call clean_up()
91+
92+
contains
93+
94+
! Subroutine for freeing memory and nullifying pointers used in the unit test
95+
subroutine clean_up()
96+
nullify(out_data)
97+
call torch_tensor_delete(tensor)
98+
end subroutine clean_up
9099

91100
end subroutine test_torch_tensor_zeros
92101

@@ -132,12 +141,21 @@ subroutine test_torch_tensor_ones()
132141
! Check that the tensor values are all one
133142
expected(:,:) = 1.0
134143
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_ones")
135-
@assertTrue(test_pass)
136-
@assertEqual(shape(out_data), shape(expected))
144+
if (.not. test_pass) then
145+
call clean_up()
146+
print *, "Error :: incorrect output from torch_tensor_ones subroutine"
147+
stop 999
148+
end if
137149

138-
! Cleanup
139-
nullify(out_data)
140-
call torch_tensor_delete(tensor)
150+
call clean_up()
151+
152+
contains
153+
154+
! Subroutine for freeing memory and nullifying pointers used in the unit test
155+
subroutine clean_up()
156+
nullify(out_data)
157+
call torch_tensor_delete(tensor)
158+
end subroutine clean_up
141159

142160
end subroutine test_torch_tensor_ones
143161

@@ -184,12 +202,21 @@ subroutine test_torch_from_array_1d()
184202
! Compare the data in the tensor to the input data
185203
expected(:) = in_data
186204
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_array")
187-
@assertTrue(test_pass)
188-
@assertEqual(shape(out_data), shape(expected))
205+
if (.not. test_pass) then
206+
call clean_up()
207+
print *, "Error :: incorrect output from torch_tensor_from_array subroutine"
208+
stop 999
209+
end if
189210

190-
! Cleanup
191-
nullify(out_data)
192-
call torch_tensor_delete(tensor)
211+
call clean_up()
212+
213+
contains
214+
215+
! Subroutine for freeing memory and nullifying pointers used in the unit test
216+
subroutine clean_up()
217+
nullify(out_data)
218+
call torch_tensor_delete(tensor)
219+
end subroutine clean_up
193220

194221
end subroutine test_torch_from_array_1d
195222

@@ -236,12 +263,21 @@ subroutine test_torch_from_array_2d()
236263
! Compare the data in the tensor to the input data
237264
expected(:,:) = in_data
238265
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_array")
239-
@assertTrue(test_pass)
240-
@assertEqual(shape(out_data), shape(expected))
266+
if (.not. test_pass) then
267+
call clean_up()
268+
print *, "Error :: incorrect output from torch_tensor_from_array subroutine"
269+
stop 999
270+
end if
241271

242-
! Cleanup
243-
nullify(out_data)
244-
call torch_tensor_delete(tensor)
272+
call clean_up()
273+
274+
contains
275+
276+
! Subroutine for freeing memory and nullifying pointers used in the unit test
277+
subroutine clean_up()
278+
nullify(out_data)
279+
call torch_tensor_delete(tensor)
280+
end subroutine clean_up
245281

246282
end subroutine test_torch_from_array_2d
247283

@@ -286,12 +322,21 @@ subroutine test_torch_from_array_3d()
286322
! Compare the data in the tensor to the input data
287323
expected(:,:,:) = in_data
288324
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_array")
289-
@assertTrue(test_pass)
290-
@assertEqual(shape(out_data), shape(expected))
325+
if (.not. test_pass) then
326+
call clean_up()
327+
print *, "Error :: incorrect output from torch_tensor_from_array subroutine"
328+
stop 999
329+
end if
291330

292-
! Cleanup
293-
nullify(out_data)
294-
call torch_tensor_delete(tensor)
331+
call clean_up()
332+
333+
contains
334+
335+
! Subroutine for freeing memory and nullifying pointers used in the unit test
336+
subroutine clean_up()
337+
nullify(out_data)
338+
call torch_tensor_delete(tensor)
339+
end subroutine clean_up
295340

296341
end subroutine test_torch_from_array_3d
297342

@@ -340,11 +385,20 @@ subroutine test_torch_from_blob()
340385
! Compare the data in the tensor to the input data
341386
expected(:,:) = in_data
342387
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_blob")
343-
@assertTrue(test_pass)
344-
@assertEqual(shape(out_data), shape(expected))
388+
if (.not. test_pass) then
389+
call clean_up()
390+
print *, "Error :: incorrect output from torch_tensor_from_array subroutine"
391+
stop 999
392+
end if
345393

346-
! Cleanup
347-
nullify(out_data)
348-
call torch_tensor_delete(tensor)
394+
call clean_up()
395+
396+
contains
397+
398+
! Subroutine for freeing memory and nullifying pointers used in the unit test
399+
subroutine clean_up()
400+
nullify(out_data)
401+
call torch_tensor_delete(tensor)
402+
end subroutine clean_up
349403

350404
end subroutine test_torch_from_blob

0 commit comments

Comments
 (0)