Skip to content

Commit 32c9387

Browse files
feat(gpu): enable division in high level api
1 parent bede76b commit 32c9387

File tree

1 file changed

+27
-9
lines changed
  • tfhe/src/high_level_api/integers/signed

1 file changed

+27
-9
lines changed

tfhe/src/high_level_api/integers/signed/ops.rs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -514,9 +514,17 @@ where
514514
)
515515
}
516516
#[cfg(feature = "gpu")]
517-
InternalServerKey::Cuda(_) => {
518-
panic!("Cuda devices does not support division yet")
519-
}
517+
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
518+
let (q, r) = cuda_key.key.key.div_rem(
519+
&*self.ciphertext.on_gpu(streams),
520+
&*rhs.ciphertext.on_gpu(streams),
521+
streams,
522+
);
523+
(
524+
FheInt::<Id>::new(q, cuda_key.tag.clone()),
525+
FheInt::<Id>::new(r, cuda_key.tag.clone()),
526+
)
527+
}),
520528
})
521529
}
522530
}
@@ -847,9 +855,14 @@ generic_integer_impl_operation!(
847855
FheInt::new(inner_result, cpu_key.tag.clone())
848856
},
849857
#[cfg(feature = "gpu")]
850-
InternalServerKey::Cuda(_cuda_key) => {
851-
panic!("Division '/' is not yet supported by Cuda devices")
852-
}
858+
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
859+
let inner_result =
860+
cuda_key
861+
.key
862+
.key
863+
.div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
864+
FheInt::new(inner_result, cuda_key.tag.clone())
865+
}),
853866
})
854867
}
855868
},
@@ -893,9 +906,14 @@ generic_integer_impl_operation!(
893906
FheInt::new(inner_result, cpu_key.tag.clone())
894907
},
895908
#[cfg(feature = "gpu")]
896-
InternalServerKey::Cuda(_cuda_key) => {
897-
panic!("Remainder/Modulo '%' is not yet supported by Cuda devices")
898-
}
909+
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
910+
let inner_result =
911+
cuda_key
912+
.key
913+
.key
914+
.rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams);
915+
FheInt::new(inner_result, cuda_key.tag.clone())
916+
}),
899917
})
900918
}
901919
},

0 commit comments

Comments
 (0)