Skip to content

mutual information and related methods #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions src/main/java/net/imglib2/algorithm/stats/InformationMetrics.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package net.imglib2.algorithm.stats;

import java.util.ArrayList;
import java.util.List;

import net.imglib2.Cursor;
import net.imglib2.IterableInterval;
import net.imglib2.histogram.BinMapper1d;
import net.imglib2.histogram.Histogram1d;
import net.imglib2.histogram.HistogramNd;
import net.imglib2.histogram.Real1dBinMapper;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

/**
* This class provides method for computing information metrics (entropy, mutual information)
* for imglib2.
*
* @author John Bogovic
*/
public class InformationMetrics
{

/**
* Returns the normalized mutual information of the inputs
* @param rai the RandomAccessibleInterval
* @param ra the RandomAccessible
Comment on lines +28 to +29
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent with actual types of parameters in method signature

* @return the normalized mutual information
*/
public static <T extends RealType< T >> double normalizedMutualInformation(
IterableInterval< T > dataA,
IterableInterval< T > dataB,
double histmin, double histmax, int numBins )
{
HistogramNd< T > jointHist = jointHistogram( dataA, dataB, histmin, histmax, numBins );

double HA = marginalEntropy( jointHist, 0 );
double HB = marginalEntropy( jointHist, 1 );
double HAB = entropy( jointHist );
return ( HA + HB ) / HAB;
}

/**
* Returns the normalized mutual information of the inputs
* @param rai the RandomAccessibleInterval
* @param ra the RandomAccessible
* @return the normalized mutual information
*/
public static <T extends RealType< T >> double mutualInformation(
IterableInterval< T > dataA,
IterableInterval< T > dataB,
double histmin, double histmax, int numBins )
{
HistogramNd< T > jointHist = jointHistogram( dataA, dataB, histmin, histmax, numBins );

double HA = marginalEntropy( jointHist, 0 );
double HB = marginalEntropy( jointHist, 1 );
double HAB = entropy( jointHist );

return HA + HB - HAB;
Comment on lines +68 to +72
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should call mutualInformation( jointHist ) instead of duplicating the code.

}

public static <T extends RealType< T >> HistogramNd<T> jointHistogram(
IterableInterval< T > dataA,
IterableInterval< T > dataB,
double histmin, double histmax, int numBins )
{
Real1dBinMapper<T> binMapper = new Real1dBinMapper<T>( histmin, histmax, numBins, false );
ArrayList<BinMapper1d<T>> binMappers = new ArrayList<BinMapper1d<T>>( 2 );
binMappers.add( binMapper );
binMappers.add( binMapper );

List<Iterable<T>> data = new ArrayList<Iterable<T>>( 2 );
data.add( dataA );
data.add( dataB );
return new HistogramNd<T>( data, binMappers );
}

/**
* Returns the joint entropy of the inputs
* @param rai the RandomAccessibleInterval
* @param ra the RandomAccessible
* @return the joint entropy
*/
public static <T extends RealType< T >> double jointEntropy(
IterableInterval< T > dataA,
IterableInterval< T > dataB,
double histmin, double histmax, int numBins )
{
return entropy( jointHistogram( dataA, dataB, histmin, histmax, numBins ));
}

/**
* Returns the entropy of the input.
*
* @param data the data
* @return the entropy
*/
public static <T extends RealType< T >> double entropy(
IterableInterval< T > data,
double histmin, double histmax, int numBins )
{
Real1dBinMapper<T> binMapper = new Real1dBinMapper<T>(
histmin, histmax, numBins, false );
final Histogram1d<T> hist = new Histogram1d<T>( binMapper );
hist.countData( data );

return entropy( hist );
}

/**
* Computes the entropy of the input 1d histogram.
* @param hist the histogram
* @return the entropy
*/
public static < T > double entropy( Histogram1d< T > hist )
{
double entropy = 0.0;
for( int i = 0; i < hist.getBinCount(); i++ )
{
double p = hist.relativeFrequency( i, false );
if( p > 0 )
entropy -= p * Math.log( p );

}
return entropy;
}

/**
* Computes the entropy of the input nd histogram.
* @param hist the histogram
* @return the entropy
*/
public static < T > double entropy( HistogramNd< T > hist )
{
double entropy = 0.0;
Cursor< LongType > hc = hist.cursor();
long[] pos = new long[ hc.numDimensions() ];

while( hc.hasNext() )
{
hc.fwd();
hc.localize( pos );
double p = hist.relativeFrequency( pos, false );
if( p > 0 )
entropy -= p * Math.log( p );

}
return entropy;
}

public static < T > double marginalEntropy( HistogramNd< T > hist, int dim )
{

final long ni = hist.dimension( dim );
final long total = hist.valueCount();
long count = 0;
double entropy = 0.0;
long ctot = 0;
for( int i = 0; i < ni; i++ )
{
count = subHistCount( hist, dim, i );
ctot += count;
double p = 1.0 * count / total;

if( p > 0 )
entropy -= p * Math.log( p );
}
Comment on lines +195 to +203
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something we can run multi-threaded using LoopBuilder? See also #83.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I'll check 👍 thanks @imagejan

return entropy;
}

private static < T > long subHistCount( HistogramNd< T > hist, int dim, int pos )
{
long count = 0;
IntervalView< LongType > hs = Views.hyperSlice( hist, dim, pos );
Cursor< LongType > c = hs.cursor();
while( c.hasNext() )
{
count += c.next().get();
}
return count;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package net.imglib2.algorithm.stats;

import static org.junit.Assert.assertEquals;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.numeric.integer.IntType;

public class InformationMetricsTests
{

private Img< IntType > imgZeros;
private Img< IntType > img3;
private Img< IntType > img3Shifted;
private Img< IntType > imgTwo;

// mutual info of a set with itself
private double MI_id = 1.0986122886681096;

@Before
public void setup()
{
int[] a = new int[]{ 0, 1, 2, 0, 1, 2, 0, 1, 2};
img3 = ArrayImgs.ints( a, a.length );

int[] b = new int[]{ 0, 1, 2, 1, 2, 0, 2, 0, 1};
img3Shifted = ArrayImgs.ints( b, b.length );

imgZeros = ArrayImgs.ints( new int[ 9 ], 9 );

int[] c = new int[]{ 0, 1, 0, 1, 0, 1, 0, 1 };
imgTwo = ArrayImgs.ints( c, c.length );
}

@Test
public void testEntropy()
{
double entropyZeros = InformationMetrics.entropy( imgZeros, 0, 1, 2 );
double entropyCoinFlip = InformationMetrics.entropy( imgTwo, 0, 1, 2 );

/*
* These tests fail
*/
// assertEquals( 0.0, entropyZeros, 1e-6 );
// assertEquals( 1.0, entropyCoinFlip, 1e-6 );

// System.out.println( "entropy zeros : " + entropyZeros );
}

@Test
public void testMutualInformation()
{
double miAA = InformationMetrics.mutualInformation( img3, img3, 0, 2, 3 );
double nmiAA = InformationMetrics.normalizedMutualInformation( img3, img3, 0, 2, 3 );

double miAB = InformationMetrics.mutualInformation( img3, img3Shifted, 0, 2, 3 );
double nmiAB = InformationMetrics.normalizedMutualInformation( img3, img3Shifted, 0, 2, 3 );

double miBA = InformationMetrics.mutualInformation( img3Shifted, img3, 0, 2, 3 );
double nmiBA = InformationMetrics.normalizedMutualInformation( img3Shifted, img3, 0, 2, 3 );

double miBB = InformationMetrics.mutualInformation( img3Shifted, img3Shifted, 0, 2, 3 );

assertEquals( "self MI", MI_id, miAA, 1e-6 );
assertEquals( "self MI", MI_id, miBB, 1e-6 );

// assertEquals( "MI symmetry", miAA, miBA, 1e-6 );
//
// System.out.println( "mi:" );
// System.out.println( miAA );
// System.out.println( miAB );
// System.out.println( "nmi:" );
// System.out.println( nmiAA );
// System.out.println( nmiAB );
}

}