Skip to content

Commit 39c3424

Browse files
authored
Implement torch_tensor_negative for taking negative of a tensor (#247)
Implement torch_tensor_negative with tests
1 parent 5522fca commit 39c3424

File tree

5 files changed

+121
-0
lines changed

5 files changed

+121
-0
lines changed

src/ctorch.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,14 @@ torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1,
268268
return output;
269269
}
270270

271+
torch_tensor_t torch_tensor_negative(const torch_tensor_t tensor) {
272+
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
273+
torch::Tensor *output = nullptr;
274+
output = new torch::Tensor;
275+
*output = -*t;
276+
return output;
277+
}
278+
271279
torch_tensor_t torch_tensor_subtract(const torch_tensor_t tensor1,
272280
const torch_tensor_t tensor2) {
273281
auto t1 = reinterpret_cast<torch::Tensor *const>(tensor1);

src/ctorch.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,13 @@ EXPORT_C torch_tensor_t torch_tensor_assign(const torch_tensor_t input);
175175
EXPORT_C torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1,
176176
const torch_tensor_t tensor2);
177177

178+
/**
179+
* Overloads the minus operator for a single Torch Tensor
180+
* @param Tensor to take the negative of
181+
* @return the negative Tensor
182+
*/
183+
EXPORT_C torch_tensor_t torch_tensor_negative(const torch_tensor_t tensor);
184+
178185
/**
179186
* Overloads the subtraction operator for two Torch Tensors
180187
* @param first Tensor to be subtracted

src/ftorch.F90

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ end function torch_to_blob_c
181181
end interface
182182

183183
interface operator (-)
184+
module procedure torch_tensor_negative
184185
module procedure torch_tensor_subtract
185186
end interface
186187

@@ -2409,6 +2410,24 @@ end function torch_tensor_add_c
24092410
output%p = torch_tensor_add_c(tensor1%p, tensor2%p)
24102411
end function torch_tensor_add
24112412

2413+
!> Overloads negative operator for a single tensor.
2414+
function torch_tensor_negative(tensor) result(output)
2415+
type(torch_tensor), intent(in) :: tensor
2416+
type(torch_tensor) :: output
2417+
2418+
interface
2419+
function torch_tensor_negative_c(tensor_c) result(output_c) &
2420+
bind(c, name = 'torch_tensor_negative')
2421+
use, intrinsic :: iso_c_binding, only : c_ptr
2422+
implicit none
2423+
type(c_ptr), value, intent(in) :: tensor_c
2424+
type(c_ptr) :: output_c
2425+
end function torch_tensor_negative_c
2426+
end interface
2427+
2428+
output%p = torch_tensor_negative_c(tensor%p)
2429+
end function torch_tensor_negative
2430+
24122431
!> Overloads subtraction operator for two tensors.
24132432
function torch_tensor_subtract(tensor1, tensor2) result(output)
24142433
type(torch_tensor), intent(in) :: tensor1

src/ftorch.fypp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ module ftorch
150150
end interface
151151

152152
interface operator (-)
153+
module procedure torch_tensor_negative
153154
module procedure torch_tensor_subtract
154155
end interface
155156

@@ -632,6 +633,24 @@ contains
632633
output%p = torch_tensor_add_c(tensor1%p, tensor2%p)
633634
end function torch_tensor_add
634635

636+
!> Overloads negative operator for a single tensor.
637+
function torch_tensor_negative(tensor) result(output)
638+
type(torch_tensor), intent(in) :: tensor
639+
type(torch_tensor) :: output
640+
641+
interface
642+
function torch_tensor_negative_c(tensor_c) result(output_c) &
643+
bind(c, name = 'torch_tensor_negative')
644+
use, intrinsic :: iso_c_binding, only : c_ptr
645+
implicit none
646+
type(c_ptr), value, intent(in) :: tensor_c
647+
type(c_ptr) :: output_c
648+
end function torch_tensor_negative_c
649+
end interface
650+
651+
output%p = torch_tensor_negative_c(tensor%p)
652+
end function torch_tensor_negative
653+
635654
!> Overloads subtraction operator for two tensors.
636655
function torch_tensor_subtract(tensor1, tensor2) result(output)
637656
type(torch_tensor), intent(in) :: tensor1

src/test/unit/test_tensor_operator_overloads.pf

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,74 @@ subroutine test_torch_tensor_add()
153153

154154
end subroutine test_torch_tensor_add
155155

156+
@test
157+
subroutine test_torch_tensor_negative()
158+
use FUnit
159+
use ftorch, only: assignment(=), operator(-), ftorch_int, torch_kCPU, torch_kFloat32, &
160+
torch_tensor, torch_tensor_delete, torch_tensor_empty, &
161+
torch_tensor_from_array, torch_tensor_to_array
162+
use ftorch_test_utils, only: assert_allclose
163+
use, intrinsic :: iso_fortran_env, only: sp => real32
164+
use, intrinsic :: iso_c_binding, only : c_associated, c_int64_t
165+
166+
implicit none
167+
168+
! Set working precision for reals
169+
integer, parameter :: wp = sp
170+
171+
type(torch_tensor) :: tensor1, tensor2
172+
integer, parameter :: ndims = 2
173+
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
174+
integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3]
175+
integer, parameter :: dtype = torch_kFloat32
176+
integer, parameter :: device_type = torch_kCPU
177+
real(wp), dimension(2,3), target :: in_data
178+
real(wp), dimension(:,:), pointer :: out_data
179+
real(wp), dimension(2,3) :: expected
180+
logical :: test_pass
181+
182+
! Create two arbitrary input arrays
183+
in_data(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
184+
185+
! Create tensors based off the input array
186+
call torch_tensor_from_array(tensor1, in_data, tensor_layout, device_type)
187+
188+
! Create another empty tensor and assign it to the negative of the first using the overloaded
189+
! negative operator
190+
call torch_tensor_empty(tensor2, ndims, tensor_shape, dtype, device_type)
191+
tensor2 = -tensor1
192+
193+
! Check input arrays are unchanged by the negation
194+
expected(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
195+
if (.not. assert_allclose(in_data, expected, test_name="test_torch_tensor_negative")) then
196+
call clean_up()
197+
print *, "Error :: first input array was changed during subtraction"
198+
stop 999
199+
end if
200+
201+
! Extract Fortran array from the assigned tensor and compare the data in the tensor to the
202+
! negative of the input array
203+
call torch_tensor_to_array(tensor2, out_data, shape(in_data))
204+
expected(:,:) = -in_data
205+
if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_negative")) then
206+
call clean_up()
207+
print *, "Error :: incorrect output from overloaded subtraction operator"
208+
stop 999
209+
end if
210+
211+
call clean_up()
212+
213+
contains
214+
215+
! Subroutine for freeing memory and nullifying pointers used in the unit test
216+
subroutine clean_up()
217+
nullify(out_data)
218+
call torch_tensor_delete(tensor1)
219+
call torch_tensor_delete(tensor2)
220+
end subroutine clean_up
221+
222+
end subroutine test_torch_tensor_negative
223+
156224
@test
157225
subroutine test_torch_tensor_subtract()
158226
use FUnit

0 commit comments

Comments
 (0)