forked from oracle/graal
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathr_mutual.cpp
30 lines (27 loc) · 823 Bytes
/
r_mutual.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include <RcppArmadillo.h>
#include <cmath>
//[[Rcpp::depends(RcppArmadillo)]]
using namespace Rcpp;
// Mutual information (in nats) of the two variables described by a
// contingency / joint-distribution table.
//
// The table is first normalised so its entries sum to 1, giving p(i, j);
// the row and column sums of the normalised table are the marginals p(i)
// and p(j). The result is
//     I = sum_{i,j : p(i,j) > 0} p(i,j) * (log p(i,j) - log p(i) - log p(j))
// using the convention 0 * log 0 = 0 for empty cells.
//
// @param joint_dist  Non-negative table of counts or (unnormalised)
//                    probabilities; rows index one variable, columns the
//                    other. Taken by value because it is normalised in place.
// @return Mutual information in nats. Returns 0 for an empty table, and NaN
//         when the table contains non-finite values (NA/NaN/Inf) or sums to
//         zero, since no distribution can be formed from it.
// @throws Rcpp::exception (via Rcpp::stop) if any entry is negative.
// [[Rcpp::export]]
double mutual_cpp(arma::mat joint_dist){
  // An empty table carries no information (and the loops below would not run).
  if (joint_dist.is_empty()) {
    return 0.0;
  }
  // NA/NaN/Inf cannot be normalised into a distribution; propagate NaN.
  // Checked before min() because min() is not meaningful in the presence of NaN.
  if (!joint_dist.is_finite()) {
    return std::nan("");
  }
  if (joint_dist.min() < 0) {
    Rcpp::stop("mutual_cpp: joint_dist must not contain negative entries");
  }
  const double total = arma::accu(joint_dist);
  if (total == 0) {
    // All-zero table: normalising would be 0/0.
    return std::nan("");
  }
  joint_dist /= total;

  // Marginals: column sums give p(j), row sums give p(i).
  const arma::rowvec col_marginal = arma::sum(joint_dist, 0);
  const arma::colvec row_marginal = arma::sum(joint_dist, 1);
  // Work in log space so tiny marginals cannot underflow when multiplied.
  // Zero marginals map to -inf here, but such entries are never read: with
  // non-negative input, p(i,j) > 0 implies p(i) > 0 and p(j) > 0.
  const arma::rowvec log_col = arma::log(col_marginal);
  const arma::colvec log_row = arma::log(row_marginal);

  double mutual_information = 0.0;
  for (arma::uword i = 0; i < joint_dist.n_rows; ++i) {
    for (arma::uword j = 0; j < joint_dist.n_cols; ++j) {
      const double p_ij = joint_dist(i, j);
      if (p_ij > 0) {  // 0 * log 0 = 0: empty cells contribute nothing.
        mutual_information += p_ij * (std::log(p_ij) - log_row[i] - log_col[j]);
      }
    }
  }
  return mutual_information;
}
// Debug helper: returns, under the name "sum", whatever Armadillo's
// one-argument sum() produces for the input matrix (for an arma::mat this is
// the column sums, as a 1 x n_cols matrix), so it can be inspected from R.
// [[Rcpp::export]]
List mutual_test(arma::mat joint_dist){
  const arma::mat per_column_sums = sum(joint_dist);
  List result = List::create(Named("sum") = per_column_sums);
  return result;
}