Skip to content

Commit

Permalink
Implement torch_tensor_negative for taking negative of a tensor (#247)
Browse files Browse the repository at this point in the history
Implement torch_tensor_negative with tests
  • Loading branch information
jwallwork23 authored Jan 20, 2025
1 parent 5522fca commit 39c3424
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/ctorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,14 @@ torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1,
return output;
}

torch_tensor_t torch_tensor_negative(const torch_tensor_t tensor) {
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
torch::Tensor *output = nullptr;
output = new torch::Tensor;
*output = -*t;
return output;
}

torch_tensor_t torch_tensor_subtract(const torch_tensor_t tensor1,
const torch_tensor_t tensor2) {
auto t1 = reinterpret_cast<torch::Tensor *const>(tensor1);
Expand Down
7 changes: 7 additions & 0 deletions src/ctorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ EXPORT_C torch_tensor_t torch_tensor_assign(const torch_tensor_t input);
EXPORT_C torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1,
const torch_tensor_t tensor2);

/**
* Overloads the minus operator for a single Torch Tensor
* @param Tensor to take the negative of
* @return the negative Tensor
*/
EXPORT_C torch_tensor_t torch_tensor_negative(const torch_tensor_t tensor);

/**
* Overloads the subtraction operator for two Torch Tensors
* @param first Tensor to be subtracted
Expand Down
19 changes: 19 additions & 0 deletions src/ftorch.F90
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ end function torch_to_blob_c
end interface

interface operator (-)
module procedure torch_tensor_negative
module procedure torch_tensor_subtract
end interface

Expand Down Expand Up @@ -2409,6 +2410,24 @@ end function torch_tensor_add_c
output%p = torch_tensor_add_c(tensor1%p, tensor2%p)
end function torch_tensor_add

!> Overloads negative operator for a single tensor.
function torch_tensor_negative(tensor) result(output)
type(torch_tensor), intent(in) :: tensor
type(torch_tensor) :: output

interface
function torch_tensor_negative_c(tensor_c) result(output_c) &
bind(c, name = 'torch_tensor_negative')
use, intrinsic :: iso_c_binding, only : c_ptr
implicit none
type(c_ptr), value, intent(in) :: tensor_c
type(c_ptr) :: output_c
end function torch_tensor_negative_c
end interface

output%p = torch_tensor_negative_c(tensor%p)
end function torch_tensor_negative

!> Overloads subtraction operator for two tensors.
function torch_tensor_subtract(tensor1, tensor2) result(output)
type(torch_tensor), intent(in) :: tensor1
Expand Down
19 changes: 19 additions & 0 deletions src/ftorch.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ module ftorch
end interface

interface operator (-)
module procedure torch_tensor_negative
module procedure torch_tensor_subtract
end interface

Expand Down Expand Up @@ -632,6 +633,24 @@ contains
output%p = torch_tensor_add_c(tensor1%p, tensor2%p)
end function torch_tensor_add

!> Overloads negative operator for a single tensor.
function torch_tensor_negative(tensor) result(output)
type(torch_tensor), intent(in) :: tensor
type(torch_tensor) :: output

interface
function torch_tensor_negative_c(tensor_c) result(output_c) &
bind(c, name = 'torch_tensor_negative')
use, intrinsic :: iso_c_binding, only : c_ptr
implicit none
type(c_ptr), value, intent(in) :: tensor_c
type(c_ptr) :: output_c
end function torch_tensor_negative_c
end interface

output%p = torch_tensor_negative_c(tensor%p)
end function torch_tensor_negative

!> Overloads subtraction operator for two tensors.
function torch_tensor_subtract(tensor1, tensor2) result(output)
type(torch_tensor), intent(in) :: tensor1
Expand Down
68 changes: 68 additions & 0 deletions src/test/unit/test_tensor_operator_overloads.pf
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,74 @@ subroutine test_torch_tensor_add()

end subroutine test_torch_tensor_add

@test
subroutine test_torch_tensor_negative()
use FUnit
use ftorch, only: assignment(=), operator(-), ftorch_int, torch_kCPU, torch_kFloat32, &
torch_tensor, torch_tensor_delete, torch_tensor_empty, &
torch_tensor_from_array, torch_tensor_to_array
use ftorch_test_utils, only: assert_allclose
use, intrinsic :: iso_fortran_env, only: sp => real32
use, intrinsic :: iso_c_binding, only : c_associated, c_int64_t

implicit none

! Set working precision for reals
integer, parameter :: wp = sp

type(torch_tensor) :: tensor1, tensor2
integer, parameter :: ndims = 2
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3]
integer, parameter :: dtype = torch_kFloat32
integer, parameter :: device_type = torch_kCPU
real(wp), dimension(2,3), target :: in_data
real(wp), dimension(:,:), pointer :: out_data
real(wp), dimension(2,3) :: expected
logical :: test_pass

! Create two arbitrary input arrays
in_data(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])

! Create tensors based off the input array
call torch_tensor_from_array(tensor1, in_data, tensor_layout, device_type)

! Create another empty tensor and assign it to the negative of the first using the overloaded
! negative operator
call torch_tensor_empty(tensor2, ndims, tensor_shape, dtype, device_type)
tensor2 = -tensor1

! Check input arrays are unchanged by the negation
expected(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
if (.not. assert_allclose(in_data, expected, test_name="test_torch_tensor_negative")) then
call clean_up()
print *, "Error :: first input array was changed during subtraction"
stop 999
end if

! Extract Fortran array from the assigned tensor and compare the data in the tensor to the
! negative of the input array
call torch_tensor_to_array(tensor2, out_data, shape(in_data))
expected(:,:) = -in_data
if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_negative")) then
call clean_up()
print *, "Error :: incorrect output from overloaded subtraction operator"
stop 999
end if

call clean_up()

contains

! Subroutine for freeing memory and nullifying pointers used in the unit test
subroutine clean_up()
nullify(out_data)
call torch_tensor_delete(tensor1)
call torch_tensor_delete(tensor2)
end subroutine clean_up

end subroutine test_torch_tensor_negative

@test
subroutine test_torch_tensor_subtract()
use FUnit
Expand Down

0 comments on commit 39c3424

Please sign in to comment.