Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement torch_tensor_negative for taking negative of a tensor #247

Merged
merged 4 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/ctorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,14 @@ torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1,
return output;
}

torch_tensor_t torch_tensor_negative(const torch_tensor_t tensor) {
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
torch::Tensor *output = nullptr;
output = new torch::Tensor;
*output = -*t;
return output;
}

torch_tensor_t torch_tensor_subtract(const torch_tensor_t tensor1,
const torch_tensor_t tensor2) {
auto t1 = reinterpret_cast<torch::Tensor *const>(tensor1);
Expand Down
7 changes: 7 additions & 0 deletions src/ctorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ EXPORT_C torch_tensor_t torch_tensor_assign(const torch_tensor_t input);
EXPORT_C torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1,
const torch_tensor_t tensor2);

/**
* Overloads the minus operator for a single Torch Tensor
* @param Tensor to take the negative of
* @return the negative Tensor
*/
EXPORT_C torch_tensor_t torch_tensor_negative(const torch_tensor_t tensor);

/**
* Overloads the subtraction operator for two Torch Tensors
* @param first Tensor to be subtracted
Expand Down
19 changes: 19 additions & 0 deletions src/ftorch.F90
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ end function torch_to_blob_c
end interface

interface operator (-)
module procedure torch_tensor_negative
module procedure torch_tensor_subtract
end interface

Expand Down Expand Up @@ -2409,6 +2410,24 @@ end function torch_tensor_add_c
output%p = torch_tensor_add_c(tensor1%p, tensor2%p)
end function torch_tensor_add

!> Overloads negative operator for a single tensor.
function torch_tensor_negative(tensor) result(output)
type(torch_tensor), intent(in) :: tensor
type(torch_tensor) :: output

interface
function torch_tensor_negative_c(tensor_c) result(output_c) &
bind(c, name = 'torch_tensor_negative')
use, intrinsic :: iso_c_binding, only : c_ptr
implicit none
type(c_ptr), value, intent(in) :: tensor_c
type(c_ptr) :: output_c
end function torch_tensor_negative_c
end interface

output%p = torch_tensor_negative_c(tensor%p)
end function torch_tensor_negative

!> Overloads subtraction operator for two tensors.
function torch_tensor_subtract(tensor1, tensor2) result(output)
type(torch_tensor), intent(in) :: tensor1
Expand Down
19 changes: 19 additions & 0 deletions src/ftorch.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ module ftorch
end interface

interface operator (-)
module procedure torch_tensor_negative
module procedure torch_tensor_subtract
end interface

Expand Down Expand Up @@ -632,6 +633,24 @@ contains
output%p = torch_tensor_add_c(tensor1%p, tensor2%p)
end function torch_tensor_add

!> Overloads negative operator for a single tensor.
function torch_tensor_negative(tensor) result(output)
type(torch_tensor), intent(in) :: tensor
type(torch_tensor) :: output

interface
function torch_tensor_negative_c(tensor_c) result(output_c) &
bind(c, name = 'torch_tensor_negative')
use, intrinsic :: iso_c_binding, only : c_ptr
implicit none
type(c_ptr), value, intent(in) :: tensor_c
type(c_ptr) :: output_c
end function torch_tensor_negative_c
end interface

output%p = torch_tensor_negative_c(tensor%p)
end function torch_tensor_negative

!> Overloads subtraction operator for two tensors.
function torch_tensor_subtract(tensor1, tensor2) result(output)
type(torch_tensor), intent(in) :: tensor1
Expand Down
68 changes: 68 additions & 0 deletions src/test/unit/test_tensor_operator_overloads.pf
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,74 @@ subroutine test_torch_tensor_add()

end subroutine test_torch_tensor_add

@test
subroutine test_torch_tensor_negative()
use FUnit
use ftorch, only: assignment(=), operator(-), ftorch_int, torch_kCPU, torch_kFloat32, &
torch_tensor, torch_tensor_delete, torch_tensor_empty, &
torch_tensor_from_array, torch_tensor_to_array
use ftorch_test_utils, only: assert_allclose
use, intrinsic :: iso_fortran_env, only: sp => real32
use, intrinsic :: iso_c_binding, only : c_associated, c_int64_t

implicit none

! Set working precision for reals
integer, parameter :: wp = sp

type(torch_tensor) :: tensor1, tensor2
integer, parameter :: ndims = 2
integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3]
integer, parameter :: dtype = torch_kFloat32
integer, parameter :: device_type = torch_kCPU
real(wp), dimension(2,3), target :: in_data
real(wp), dimension(:,:), pointer :: out_data
real(wp), dimension(2,3) :: expected
logical :: test_pass

! Create two arbitrary input arrays
in_data(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])

! Create tensors based off the input array
call torch_tensor_from_array(tensor1, in_data, tensor_layout, device_type)

! Create another empty tensor and assign it to the negative of the first using the overloaded
! negative operator
call torch_tensor_empty(tensor2, ndims, tensor_shape, dtype, device_type)
tensor2 = -tensor1

! Check input arrays are unchanged by the subtraction
expected(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
if (.not. assert_allclose(in_data, expected, test_name="test_torch_tensor_negative")) then
call clean_up()
print *, "Error :: first input array was changed during subtraction"
stop 999
end if

! Extract Fortran array from the assigned tensor and compare the data in the tensor to the
! negative of the input array
call torch_tensor_to_array(tensor2, out_data, shape(in_data))
expected(:,:) = -in_data
if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_subtract")) then
call clean_up()
print *, "Error :: incorrect output from overloaded subtraction operator"
stop 999
end if

call clean_up()

contains

! Subroutine for freeing memory and nullifying pointers used in the unit test
subroutine clean_up()
nullify(out_data)
call torch_tensor_delete(tensor1)
call torch_tensor_delete(tensor2)
end subroutine clean_up

end subroutine test_torch_tensor_negative

@test
subroutine test_torch_tensor_subtract()
use FUnit
Expand Down
Loading