Skip to content

Commit be30e72

Browse files
author
Jack Dermody
committed
small fixes
1 parent 54bce96 commit be30e72

File tree

9 files changed

+143
-104
lines changed

9 files changed

+143
-104
lines changed

Diff for: BrightData/BrightData.xml

+3-4
Original file line numberDiff line numberDiff line change
@@ -2655,7 +2655,6 @@
26552655
<summary>
26562656
Converts the typed buffer to a buffer of objects
26572657
</summary>
2658-
<typeparam name="T"></typeparam>
26592658
<param name="buffer"></param>
26602659
<returns></returns>
26612660
</member>
@@ -8447,12 +8446,12 @@
84478446
</member>
84488447
<member name="P:BrightData.ITensor.Shape">
84498448
<summary>
8450-
Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc
8449+
Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc.
84518450
</summary>
84528451
</member>
84538452
<member name="T:BrightData.ITensor`1">
84548453
<summary>
8455-
Typed tensor interface - vector, matrix, 3D tensor etc
8454+
Typed tensor interface - vector, matrix, 3D tensor etc.
84568455
</summary>
84578456
</member>
84588457
<member name="M:BrightData.ITensor`1.Reshape">
@@ -8580,7 +8579,7 @@
85808579
</member>
85818580
<member name="T:BrightData.ITensorType`3">
85828581
<summary>
8583-
Typed tensor interface - vector, matrix, 3D tensor etc
8582+
Typed tensor interface - vector, matrix, 3D tensor etc.
85848583
</summary>
85858584
<typeparam name="T"></typeparam>
85868585
<typeparam name="TT"></typeparam>

Diff for: BrightData/Buffer/ReadOnly/Helper/BufferConcatenator.cs

-80
This file was deleted.

Diff for: BrightData/ExtensionMethods.Buffers.cs

-1
Original file line numberDiff line numberDiff line change
@@ -1065,7 +1065,6 @@ public static IReadOnlyBuffer ConvertUnmanagedTo(this IReadOnlyBuffer buffer, Ty
10651065
/// <summary>
10661066
/// Converts the typed buffer to a buffer of objects
10671067
/// </summary>
1068-
/// <typeparam name="T"></typeparam>
10691068
/// <param name="buffer"></param>
10701069
/// <returns></returns>
10711070
public static IReadOnlyBuffer<object> ToObjectBuffer(this IReadOnlyBuffer buffer)

Diff for: BrightData/Interfaces.LinearAlgebra.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -799,13 +799,13 @@ public interface ITensor : IDisposable, IHaveBrightDataContext
799799
uint TotalSize { get; }
800800

801801
/// <summary>
802-
/// Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc
802+
/// Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc.
803803
/// </summary>
804804
uint[] Shape { get; }
805805
}
806806

807807
/// <summary>
808-
/// Typed tensor interface - vector, matrix, 3D tensor etc
808+
/// Typed tensor interface - vector, matrix, 3D tensor etc.
809809
/// </summary>
810810
public interface ITensor<T> : ITensor, IReadOnlyTensor<T>, IHaveLinearAlgebraProvider<T>, IHaveTensorSegment<T>
811811
where T: unmanaged, IBinaryFloatingPointIeee754<T>, IMinMaxValue<T>
@@ -935,7 +935,7 @@ public interface ITensor<T> : ITensor, IReadOnlyTensor<T>, IHaveLinearAlgebraPro
935935
}
936936

937937
/// <summary>
938-
/// Typed tensor interface - vector, matrix, 3D tensor etc
938+
/// Typed tensor interface - vector, matrix, 3D tensor etc.
939939
/// </summary>
940940
/// <typeparam name="T"></typeparam>
941941
/// <typeparam name="TT"></typeparam>

Diff for: BrightData/LinearAlgebra/MutableMatrix.cs

+118-11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Linq;
55
using System.Numerics;
66
using System.Runtime.CompilerServices;
7+
using System.Runtime.InteropServices;
78
using System.Threading.Tasks;
89
using BrightData.LinearAlgebra.ReadOnly;
910
using BrightData.LinearAlgebra.Segments;
@@ -24,8 +25,8 @@ namespace BrightData.LinearAlgebra
2425
/// <param name="columns">Number of columns</param>
2526
/// <param name="lap">Linear algebra provider</param>
2627
public class MutableMatrix<T, LAP>(INumericSegment<T> data, uint rows, uint columns, LAP lap) : MutableTensorBase<T, IReadOnlyMatrix<T>, IMatrix<T>, LAP>(data, lap), IMatrix<T>
27-
where T: unmanaged, IBinaryFloatingPointIeee754<T>, IMinMaxValue<T>
28-
where LAP: LinearAlgebraProvider<T>
28+
where T : unmanaged, IBinaryFloatingPointIeee754<T>, IMinMaxValue<T>
29+
where LAP : LinearAlgebraProvider<T>
2930
{
3031
/// <inheritdoc />
3132
public uint RowCount { get; private set; } = rows;
@@ -79,31 +80,31 @@ protected set
7980
/// <inheritdoc />
8081
public INumericSegment<T> GetRow(uint index)
8182
{
// Returns a MutableTensorSegmentWrapper<T> over Segment, constructed with (index, RowCount, ColumnCount);
// presumably these are (offset, stride, length) - confirm against the wrapper's constructor.
// Throws ArgumentOutOfRangeException only when index is strictly greater than RowCount.
// NOTE(review): index == RowCount is not rejected by this guard. If row indices are zero-based
// (valid range 0..RowCount-1) this looks like an off-by-one and the test should be >= - confirm.
// The two `if` lines below are the removed (-) and added (+) sides of a whitespace-only change.
82-
if(index > RowCount)
83+
if (index > RowCount)
8384
throw new ArgumentOutOfRangeException(nameof(index), $"Number of rows is {RowCount} but index {index} was requested");
8485
return new MutableTensorSegmentWrapper<T>(Segment, index, RowCount, ColumnCount);
8586
}
8687

8788
/// <inheritdoc />
8889
public virtual INumericSegment<T> GetColumn(uint index)
8990
{
// Returns a MutableTensorSegmentWrapper<T> over Segment, constructed with (index * RowCount, 1, RowCount);
// presumably these are (offset, stride, length) - confirm against the wrapper's constructor.
// Throws ArgumentOutOfRangeException only when index is strictly greater than ColumnCount.
// NOTE(review): index == ColumnCount is not rejected by this guard. If column indices are zero-based
// (valid range 0..ColumnCount-1) this looks like an off-by-one and the test should be >= - confirm.
// The two `if` lines below are the removed (-) and added (+) sides of a whitespace-only change.
90-
if(index > ColumnCount)
91+
if (index > ColumnCount)
9192
throw new ArgumentOutOfRangeException(nameof(index), $"Number of columns is {ColumnCount} but index {index} was requested");
9293
return new MutableTensorSegmentWrapper<T>(Segment, index * RowCount, 1, RowCount);
9394
}
9495

9596
/// <inheritdoc />
9697
public virtual IReadOnlyNumericSegment<T> GetReadOnlyRow(uint index)
9798
{
// Returns a ReadOnlyTensorSegmentWrapper<T> over Segment, constructed with (index, RowCount, ColumnCount);
// presumably these are (offset, stride, length) - confirm against the wrapper's constructor.
// Throws ArgumentOutOfRangeException only when index is strictly greater than RowCount.
// NOTE(review): index == RowCount is not rejected by this guard. If row indices are zero-based
// (valid range 0..RowCount-1) this looks like an off-by-one and the test should be >= - confirm.
// The two `if` lines below are the removed (-) and added (+) sides of a whitespace-only change.
98-
if(index > RowCount)
99+
if (index > RowCount)
99100
throw new ArgumentOutOfRangeException(nameof(index), $"Number of rows is {RowCount} but index {index} was requested");
100101
return new ReadOnlyTensorSegmentWrapper<T>(Segment, index, RowCount, ColumnCount);
101102
}
102103

103104
/// <inheritdoc />
104105
public virtual IReadOnlyNumericSegment<T> GetReadOnlyColumn(uint index)
105106
{
// Returns a ReadOnlyTensorSegmentWrapper<T> over Segment, constructed with (index * RowCount, 1, RowCount);
// presumably these are (offset, stride, length) - confirm against the wrapper's constructor.
// Throws ArgumentOutOfRangeException only when index is strictly greater than ColumnCount.
// NOTE(review): index == ColumnCount is not rejected by this guard. If column indices are zero-based
// (valid range 0..ColumnCount-1) this looks like an off-by-one and the test should be >= - confirm.
// The two `if` lines below are the removed (-) and added (+) sides of a whitespace-only change.
106-
if(index > ColumnCount)
107+
if (index > ColumnCount)
107108
throw new ArgumentOutOfRangeException(nameof(index), $"Number of columns is {ColumnCount} but index {index} was requested");
108109
return new ReadOnlyTensorSegmentWrapper<T>(Segment, index * RowCount, 1, RowCount);
109110
}
@@ -317,13 +318,14 @@ static unsafe IMatrix<T> MultiplyWithThisTransposed(LinearAlgebraProvider<T> lap
317318
fixed (T* matrixPtr = matrixSpan)
318319
fixed (T* otherPtr = otherSpan)
319320
fixed (T* retPtr = retSpan) {
320-
MatrixMultiplyChunked(matrixPtr, otherPtr, lda, rowCount, columnCount, retPtr);
321+
//MatrixMultiplyChunked(matrixPtr, otherPtr, lda, rowCount, columnCount, retPtr);
322+
MatrixMultiplyTiled2(matrixPtr, otherPtr, lda, rowCount, columnCount, retPtr);
321323
}
322324
}
323325
finally {
324-
if(wasMatrixTempUsed)
326+
if (wasMatrixTempUsed)
325327
matrixTemp.Dispose();
326-
if(wasOtherTempUsed)
328+
if (wasOtherTempUsed)
327329
otherTemp.Dispose();
328330
}
329331

@@ -348,9 +350,9 @@ static unsafe void MatrixMultiplyChunked(T* a, T* b, int size, uint rows, uint c
348350

349351
return;
350352

351-
[MethodImpl(MethodImplOptions.AggressiveInlining)]void Multiply(long startIndex)
353+
[MethodImpl(MethodImplOptions.AggressiveInlining)] void Multiply(long startIndex)
352354
{
353-
for(long index = startIndex, len = Math.Min(startIndex + ChunkSize, totalSize); index < len; index++) {
355+
for (long index = startIndex, len = Math.Min(startIndex + ChunkSize, totalSize); index < len; index++) {
354356
var i = (uint)(index % rows);
355357
var j = (uint)(index / rows);
356358

@@ -371,6 +373,111 @@ static unsafe void MatrixMultiplyChunked(T* a, T* b, int size, uint rows, uint c
371373
}
372374
}
373375

376+
// Fills ret with dot products: for every i < rows and j < cols,
//   ret[j * rows + i] = sum over z in [0, size) of a[i * size + z] * b[j * size + z].
// The (i, j) pairs are visited in TileSize x TileSize tiles.
//   a, b  - input pointers; each dot product reads `size` consecutive elements starting at a[i * size] and b[j * size]
//   size  - number of elements in each dot product
//   rows  - number of distinct i values (also the multiplier applied to j when indexing ret)
//   cols  - number of distinct j values
//   ret   - output pointer, written at ret[j * rows + i]
// NOTE(review): the call site visible in this diff invokes MatrixMultiplyTiled2, not this method - confirm it is still needed.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
377+
static unsafe void MatrixMultiplyTiled(T* a, T* b, int size, uint rows, uint cols, T* ret)
378+
{
379+
const int TileSize = 32; // Size of the tile, should be adjusted based on hardware cache sizes.
380+
var vectorSize = Vector<T>.Count;
// Number of whole Vector<T> chunks that fit in `size` elements.
381+
var numVectors = size / vectorSize;
// First element index not covered by the vectorised loop; the scalar tail loop starts here.
382+
var ceiling = numVectors * vectorSize;
// NOTE(review): rows and cols are both uint, so this product is computed in uint arithmetic and
// could wrap for very large matrices - confirm whether that matters for the threshold test below.
383+
var totalSize = rows * cols;
384+
385+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
386+
void MultiplyTile(uint rowStart, uint colStart)
387+
{
// Both loops stop at the tile edge or at the matrix edge (rows / cols), whichever comes first.
388+
for (uint i = rowStart; i < rowStart + TileSize && i < rows; i++) {
389+
for (uint j = colStart; j < colStart + TileSize && j < cols; j++) {
// NOTE(review): xPtr / xSpan / xVectors depend only on i but are rebuilt for every j.
390+
var xPtr = &a[i * size];
391+
var xSpan = new ReadOnlySpan<T>(xPtr, size);
392+
var xVectors = MemoryMarshal.Cast<T, Vector<T>>(xSpan);
393+
394+
var yPtr = &b[j * size];
395+
var ySpan = new ReadOnlySpan<T>(yPtr, size);
396+
var yVectors = MemoryMarshal.Cast<T, Vector<T>>(ySpan);
397+
398+
var vSum = Vector<T>.Zero;
// Accumulate element-wise products one Vector<T> at a time.
399+
for (var z = 0; z < numVectors; z++)
400+
vSum += xVectors[z] * yVectors[z];
401+
402+
// Dot product with an all-ones vector adds up the lanes of vSum into a single scalar.
var sum = Vector.Dot(vSum, Vector<T>.One);
// Scalar tail: elements from `ceiling` up to `size` that did not fill a whole vector.
403+
for (var z = ceiling; z < size; z++)
404+
sum += xPtr[z] * yPtr[z];
405+
ret[j * rows + i] = sum;
406+
}
407+
}
408+
}
409+
410+
// At or above Consts.MinimumSizeForParallel, row tiles are distributed with Parallel.For;
// otherwise the same tiles are walked sequentially.
if (totalSize >= Consts.MinimumSizeForParallel) {
// Iteration count is the number of row tiles, rounded up to include a partial last tile.
411+
Parallel.For(0, (int)Math.Ceiling((double)rows / TileSize), rowTile => {
412+
for (uint colTile = 0; colTile < cols; colTile += TileSize) {
413+
MultiplyTile((uint)rowTile * TileSize, colTile);
414+
}
415+
});
416+
}
417+
else {
418+
for (uint rowTile = 0; rowTile < rows; rowTile += TileSize) {
419+
for (uint colTile = 0; colTile < cols; colTile += TileSize) {
420+
MultiplyTile(rowTile, colTile);
421+
}
422+
}
423+
}
424+
}
425+
426+
// Fills ret with dot products: for every ii < rows and jj < cols,
//   ret[jj * rows + ii] = sum over z in [0, size) of a[ii * size + z] * b[jj * size + z].
// The (ii, jj) pairs are visited in L2BlockSize x L2BlockSize outer blocks, each of which is
// walked in L1BlockSize x L1BlockSize inner blocks.
//   a, b  - input pointers; each dot product reads `size` consecutive elements starting at a[ii * size] and b[jj * size]
//   size  - number of elements in each dot product
//   rows  - number of distinct ii values (also the multiplier applied to jj when indexing ret)
//   cols  - number of distinct jj values
//   ret   - output pointer, written at ret[jj * rows + ii]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
427+
static unsafe void MatrixMultiplyTiled2(T* a, T* b, int size, uint rows, uint cols, T* ret)
428+
{
429+
// Inner and outer block edge lengths; NOTE(review): the names suggest CPU cache levels, but
// nothing here ties the values to actual cache sizes - confirm how they were chosen.
const int L1BlockSize = 32;
430+
const int L2BlockSize = 64;
431+
var vectorSize = Vector<T>.Count;
// Number of whole Vector<T> chunks that fit in `size` elements.
432+
var numVectors = size / vectorSize;
// First element index not covered by the vectorised loop; the scalar tail loop starts here.
433+
var ceiling = numVectors * vectorSize;
// NOTE(review): rows and cols are both uint, so this product is computed in uint arithmetic and
// could wrap for very large matrices - confirm whether that matters for the threshold test below.
434+
var totalSize = rows * cols;
435+
436+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
437+
void MultiplyBlock(uint rowStart, uint colStart, uint rowEnd, uint colEnd)
438+
{
// rowEnd / colEnd are exclusive bounds of the outer block; every loop is additionally clamped
// to rows / cols so a partial block at the matrix edge is handled.
439+
for (uint i = rowStart; i < rowEnd && i < rows; i += L1BlockSize) {
440+
for (uint j = colStart; j < colEnd && j < cols; j += L1BlockSize) {
441+
for (uint ii = i; ii < i + L1BlockSize && ii < rowEnd && ii < rows; ii++) {
442+
for (uint jj = j; jj < j + L1BlockSize && jj < colEnd && jj < cols; jj++) {
// NOTE(review): xPtr / xSpan / xVectors depend only on ii but are rebuilt for every jj.
443+
var xPtr = &a[ii * size];
444+
var xSpan = new ReadOnlySpan<T>(xPtr, size);
445+
var xVectors = MemoryMarshal.Cast<T, Vector<T>>(xSpan);
446+
447+
var yPtr = &b[jj * size];
448+
var ySpan = new ReadOnlySpan<T>(yPtr, size);
449+
var yVectors = MemoryMarshal.Cast<T, Vector<T>>(ySpan);
450+
451+
var vSum = Vector<T>.Zero;
// Accumulate element-wise products one Vector<T> at a time.
452+
for (var z = 0; z < numVectors; z++)
453+
vSum += xVectors[z] * yVectors[z];
454+
455+
// Dot product with an all-ones vector adds up the lanes of vSum into a single scalar.
var sum = Vector.Dot(vSum, Vector<T>.One);
// Scalar tail: elements from `ceiling` up to `size` that did not fill a whole vector.
456+
for (var z = ceiling; z < size; z++)
457+
sum += xPtr[z] * yPtr[z];
458+
ret[jj * rows + ii] = sum;
459+
}
460+
}
461+
}
462+
}
463+
}
464+
465+
// At or above Consts.MinimumSizeForParallel, outer row blocks are distributed with Parallel.For
// (each iteration covers its own range of rows); otherwise the same blocks are walked sequentially.
if (totalSize >= Consts.MinimumSizeForParallel) {
// Iteration count is the number of outer row blocks, rounded up to include a partial last block.
466+
Parallel.For(0, (int)Math.Ceiling((double)rows / L2BlockSize), rowTile => {
467+
for (uint colTile = 0; colTile < cols; colTile += L2BlockSize) {
468+
MultiplyBlock((uint)rowTile * L2BlockSize, colTile, (uint)((rowTile + 1) * L2BlockSize), colTile + L2BlockSize);
469+
}
470+
});
471+
}
472+
else {
473+
for (uint rowTile = 0; rowTile < rows; rowTile += L2BlockSize) {
474+
for (uint colTile = 0; colTile < cols; colTile += L2BlockSize) {
475+
MultiplyBlock(rowTile, colTile, rowTile + L2BlockSize, colTile + L2BlockSize);
476+
}
477+
}
478+
}
479+
}
480+
374481
/// <inheritdoc />
375482
public override string ToString()
376483
{

Diff for: BrightWire/BrightWire.xml

+10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: BrightWire/ExecutionGraph/Node/NodeBase.cs

+7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,14 @@ public abstract class NodeBase : ICanInitialiseNode, IDisposable, ICanSerialise
2020
string? _name;
2121
List<WireToNode> _output = [];
2222

23+
/// <summary>
24+
/// Callback method when the node has executed
25+
/// </summary>
2326
public delegate void ForwardDelegate(NodeBase previous, NodeBase current, IGraphData input, IGraphData? output);
27+
28+
/// <summary>
29+
/// Called when the node is executed
30+
/// </summary>
2431
public event ForwardDelegate? OnForward;
2532

2633
/// <summary>

Diff for: BrightWire/Models/StringTable.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
using System;
2-
3-
namespace BrightWire.Models
1+
namespace BrightWire.Models
42
{
53
/// <summary>
64
/// An array of indexed strings

Diff for: ExampleCode/Program.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
using BrightData.Cuda;
99
using BrightData.LinearAlgebra;
1010
using BrightData.MKL;
11-
using BrightData.Parquet;
1211
using BrightWire;
1312
using ExampleCode.DataSet;
1413
using ExampleCode.DataTableTrainers;
@@ -100,7 +99,7 @@ static async Task IrisClassification(BrightDataContext context, bool useMkl)
10099
Start(context, useMkl);
101100
using var iris = await context.Iris();
102101
await iris.TrainNaiveBayes();
103-
iris.TrainDecisionTree();
102+
await iris.TrainDecisionTree();
104103
await iris.TrainRandomForest(500, 7);
105104
await iris.TrainKNearestNeighbours(10);
106105
//iris.TrainMultinomialLogisticRegression(500, 0.3f, 0.1f);

0 commit comments

Comments
 (0)