TyTorch Development Roadmap

Project Status

Phase 1: Core Refactoring ✅ COMPLETE

  • All 24 tensor operations extracted to modular files
  • Test suite reorganized (unit/cpu/mps)
  • 470 tests passing across 32 test files
  • Clean, maintainable architecture established

Phase 2: Essential ML Methods 🚧 IN PROGRESS

  • Phase 2A (Shape Operations): ✅ COMPLETE - All 6 shape operations implemented
  • Phase 2B (Autograd & Activations): ✅ COMPLETE - All 6 autograd utilities plus relu(), sigmoid(), tanh(), softmax(), and log_softmax()
  • Phase 2C (Loss Functions): ✅ COMPLETE - mse_loss(), cross_entropy(), nll_loss(), and binary_cross_entropy()
  • Focus on methods needed for real machine learning work
  • Focus on methods needed for real machine learning work

✅ Completed Methods (45)

Arithmetic Operations (8)

  • ✅ add() - Element-wise addition (tensor + tensor, tensor + scalar)
  • ✅ add_() - In-place element-wise addition
  • ✅ sub() - Element-wise subtraction
  • ✅ sub_() - In-place element-wise subtraction
  • ✅ mul() - Element-wise multiplication
  • ✅ mul_() - In-place element-wise multiplication
  • ✅ div() - Element-wise division
  • ✅ div_() - In-place element-wise division

Matrix Operations (1)

  • ✅ matmul() - Matrix multiplication (dot products, matrix-vector, matrix-matrix)

Reduction Operations (2)

  • ✅ sum() - Sum all elements to scalar
  • ✅ mean() - Arithmetic mean of all elements

Device Management (4)

  • ✅ cpu() - Move tensor to CPU
  • ✅ cuda() - Move tensor to CUDA (with availability check)
  • ✅ mps() - Move tensor to Apple Silicon MPS
  • ✅ to() - Generic device/dtype conversion

Dtype Conversions (4)

  • ✅ float() - Convert to float32
  • ✅ double() - Convert to float64
  • ✅ int() - Convert to int32
  • ✅ long() - Convert to int64

Property Accessors (3)

  • ✅ shape - Get tensor shape
  • ✅ dtype - Get tensor dtype
  • ✅ device - Get tensor device

Utility Methods (2)

  • ✅ toString() - Convert to PyTorch string format
  • ✅ toArray() - Convert to JavaScript array

Shape Operations (6)

  • ✅ reshape() - Reshape tensor to new shape without copying data
  • ✅ flatten() - Flatten tensor to 1D or flatten specific dimensions
  • ✅ unsqueeze() - Add dimension of size 1 at specified position
  • ✅ squeeze() - Remove dimensions of size 1
  • ✅ transpose() - Swap two dimensions of the tensor
  • ✅ permute() - Reorder dimensions according to specified order

Autograd Operations (6)

  • ✅ requires_grad - Property to mark tensors for gradient tracking
  • ✅ backward() - Compute gradients via backpropagation
  • ✅ grad - Property to access computed gradients
  • ✅ zero_grad() - Clear gradients (sets grad to None)
  • ✅ detach() - Detach tensor from computation graph
  • ✅ no_grad() - Context function to disable gradient tracking

Activation Functions (5)

  • ✅ relu() - ReLU activation function (max(0, x))
  • ✅ sigmoid() - Sigmoid activation function (σ(x) = 1 / (1 + exp(-x)))
  • ✅ tanh() - Hyperbolic tangent activation (tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)))
  • ✅ softmax() - Softmax activation with optional dim parameter (for classification)
  • ✅ log_softmax() - Log-softmax activation (numerically stable log(softmax(x)))

Loss Functions (4)

  • ✅ mse_loss() - Mean Squared Error loss (for regression)
  • ✅ cross_entropy() - Cross Entropy loss (for classification)
  • ✅ nll_loss() - Negative Log-Likelihood loss (for classification with log-probabilities)
  • ✅ binary_cross_entropy() - Binary Cross Entropy loss (for binary classification)

🎯 High Priority: Essential ML Methods

These are the most critical methods needed to build, train, and use real ML models.

1. Tensor Shape Operations ✅ COMPLETE

  • reshape() / view() - Reshape without copying (e.g., flatten for FC layers)
  • unsqueeze() - Add dimension (e.g., add batch dimension)
  • squeeze() - Remove singleton dimensions
  • transpose() - Swap two dimensions (e.g., for convolutions)
  • permute() - Arbitrary dimension reordering
  • flatten() - Flatten to 1D (convenience method)

Why: Every ML model needs to manipulate tensor shapes for layers, batches, etc.
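
For example (a short sketch using only the completed methods above; the require path is an assumption, and call forms are expected to mirror PyTorch per the compatibility notes below):

const torch = require('tytorch'); // module name assumed

const x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); // shape [6]
const m = x.reshape([2, 3]);    // shape [2, 3], no data copy
const batched = m.unsqueeze(0); // shape [1, 2, 3] - add a batch dimension
const t = m.transpose(0, 1);    // shape [3, 2]
const flat = batched.flatten(); // back to shape [6]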

2. Autograd / Gradient Computation (CRITICAL) ✅ COMPLETE

  • requires_grad property - Mark tensors for gradient tracking
  • backward() - Compute gradients via backpropagation
  • grad property - Access computed gradients
  • zero_grad() - Clear gradients (sets grad to None)
  • detach() - Detach from computation graph
  • no_grad() context - Disable gradient tracking

Why: This is THE foundation of training. No autograd = no learning!
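
A sketch of the workflow these six utilities enable (assumes requires_grad is a settable property and that no_grad() takes a callback, since JavaScript has no context managers; call forms otherwise mirror PyTorch):

const x = torch.tensor([2.0, 3.0]);
x.requires_grad = true;        // track operations on x

const loss = x.mul(x).sum();   // loss = sum(x^2)
loss.backward();               // backpropagate
console.log(x.grad.toArray()); // d(loss)/dx = 2x -> [4, 6]

x.zero_grad();                 // reset before the next step
torch.no_grad(() => {
  x.mul(2.0);                  // no graph is recorded here
});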

3. Activation Functions (CRITICAL) ✅ COMPLETE

  • relu() - ReLU activation (most common)
  • sigmoid() - Sigmoid activation (σ(x) = 1 / (1 + exp(-x)))
  • tanh() - Hyperbolic tangent
  • softmax() - Softmax (for classification)
  • log_softmax() - Log-softmax (numerically stable)

Why: Essential for any neural network layer.
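
For example, turning raw scores into class probabilities (a sketch; the activations are shown as tensor methods, matching the completed-methods list above):

const logits = torch.tensor([2.0, 1.0, 0.1]);
const probs = logits.softmax(0);        // probabilities that sum to 1
const logProbs = logits.log_softmax(0); // numerically stable log-probabilities
const hidden = torch.tensor([-1.0, 0.5]).relu(); // [0, 0.5]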

4. Loss Functions (CRITICAL) ✅ COMPLETE

  • mse_loss() - Mean squared error (regression)
  • cross_entropy() - Cross-entropy loss (classification)
  • nll_loss() - Negative log-likelihood
  • binary_cross_entropy() - Binary classification loss

Why: Can't train without computing loss!
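
For example (a sketch; whether losses are tensor methods or torch-level functions is an assumption here):

const pred = torch.tensor([2.5, 0.0, 2.0]);
const target = torch.tensor([3.0, -0.5, 2.0]);
const loss = pred.mse_loss(target); // mean((pred - target)^2) = 0.5 / 3
console.log(loss.toString());       // approximately tensor(0.1667)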

5. Indexing & Slicing (HIGH)

  • slice() - Slice tensor along dimension
  • index_select() - Select specific indices
  • [] operator - Bracket indexing (PyTorch's tensor[0:2, :]; JavaScript has no slice syntax, so this will likely be exposed as a method)
  • narrow() - Narrow a dimension

Why: Access batches, specific samples, extract predictions.
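
Planned usage might look like this (hypothetical - none of these are implemented yet; signatures assumed to follow PyTorch, e.g., narrow(dim, start, length)):

const batch = torch.tensor([[1, 2], [3, 4], [5, 6]]);
const firstTwo = batch.narrow(0, 0, 2); // rows 0-1 along dim 0
const picked = batch.index_select(0, torch.tensor([0, 2]).long()); // rows 0 and 2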

6. Concatenation & Stacking (HIGH)

  • cat() / concat() - Concatenate along existing dimension
  • stack() - Stack tensors along new dimension
  • split() - Split tensor into chunks
  • chunk() - Split into equal-sized chunks

Why: Combine data batches, build datasets, split for mini-batches.
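
Planned usage (hypothetical; signatures assumed to follow PyTorch):

const a = torch.tensor([[1.0, 2.0]]);
const b = torch.tensor([[3.0, 4.0]]);
const joined = torch.cat([a, b], 0);    // shape [2, 2] - joins an existing dim
const stacked = torch.stack([a, b], 0); // shape [2, 1, 2] - creates a new dim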

7. Advanced Reductions (HIGH)

  • max() - Maximum value (with optional dimension)
  • min() - Minimum value
  • argmax() - Index of maximum value
  • argmin() - Index of minimum value
  • sum() with dimension - Sum along specific dimension (already have global sum)
  • mean() with dimension - Mean along specific dimension

Why: Model predictions, accuracy computation, batch statistics.
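
Planned usage, e.g., picking predicted classes (hypothetical; PyTorch-style signatures assumed):

const logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]]);
const preds = logits.argmax(1);  // index of max per row -> [1, 0]
const rowMeans = logits.mean(1); // mean along dim 1 (today's mean() is global-only)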

8. Element-wise Math Operations (MEDIUM)

  • pow() - Raise to power
  • sqrt() - Square root
  • exp() - Exponential
  • log() - Natural logarithm
  • abs() - Absolute value
  • neg() - Negate (multiply by -1)
  • clamp() - Clamp values to range

Why: Custom loss functions, normalization, numerical stability.
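
Planned usage, e.g., computing a norm and bounding values (hypothetical; PyTorch-style signatures assumed):

const v = torch.tensor([3.0, 4.0]);
const norm = v.pow(2).sum().sqrt(); // L2 norm: sqrt(9 + 16) = 5
const bounded = v.clamp(0.0, 3.5);  // [3.0, 3.5]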

9. Comparison Operations (MEDIUM)

  • eq() - Element-wise equality
  • ne() - Element-wise not-equal
  • gt() / ge() - Greater than / greater-or-equal
  • lt() / le() - Less than / less-or-equal
  • all() - Check if all elements are true
  • any() - Check if any element is true

Why: Masking, filtering, conditional operations.
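
Planned usage, e.g., computing accuracy from predictions (hypothetical; combines the planned eq() with the existing float() and mean()):

const preds = torch.tensor([1, 0, 2]).long();
const labels = torch.tensor([1, 1, 2]).long();
const correct = preds.eq(labels);        // [true, false, true]
const accuracy = correct.float().mean(); // approximately 0.6667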

10. Tensor Creation Utilities (MEDIUM)

  • clone() - Deep copy tensor
  • randint() - Random integers (for labels, indices)
  • arange() - Range of values
  • linspace() - Linearly spaced values
  • eye() - Identity matrix
  • full() - Fill with specific value

Why: Data generation, initialization, testing.
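
Planned usage (hypothetical; PyTorch-style signatures assumed, e.g., randint(low, high, shape)):

const steps = torch.arange(0, 5);          // [0, 1, 2, 3, 4]
const labels = torch.randint(0, 10, [32]); // 32 random class labels in [0, 10)
const identity = torch.eye(3);             // 3x3 identity matrix
const copied = labels.clone();             // independent deep copy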


📋 Future Methods (Lower Priority)

These methods are useful but not critical for basic ML workflows.

Tensor Manipulation

  • expand() - Expand tensor to larger size
  • repeat() - Repeat tensor along dimensions
  • tile() - Tile tensor
  • gather() - Gather values along dimension
  • scatter() - Scatter values

Advanced Operations

  • Convolution operations (conv1d, conv2d, conv3d)
  • Pooling operations (max_pool2d, avg_pool2d)
  • Normalization (batch_norm, layer_norm)
  • Dropout operations
  • RNN/LSTM operations

Linear Algebra

  • inverse() - Matrix inverse
  • svd() - Singular value decomposition
  • eig() - Eigenvalues and eigenvectors
  • qr() - QR decomposition
  • cholesky() - Cholesky decomposition

Statistics

  • std() - Standard deviation
  • var() - Variance
  • median() - Median value
  • mode() - Most common value
  • histogram() - Compute histogram

Other Operations

  • where() - Conditional selection
  • masked_select() - Select with boolean mask
  • nonzero() - Indices of non-zero elements
  • unique() - Unique elements
  • sort() - Sort tensor values

πŸ—οΈ Implementation Strategy

Phase 2A: Shape Operations ✅ COMPLETE

Goal: Enable basic tensor shape manipulation

  1. ✅ Implemented reshape(), unsqueeze(), squeeze()
  2. ✅ Implemented flatten(), transpose(), permute()
  3. ✅ 117 new tests added (27 unit, 23 CPU, 24 MPS for permute; similar for others)
  4. ✅ All shape operations working across CPU and MPS devices

Next: Phase 2C - Data operations (indexing, concatenation, advanced reductions)

Phase 2B: Activations & Loss (Week 3) ✅ COMPLETE

Goal: Enable model training

  1. ✅ Implemented activation functions (relu, sigmoid, tanh, softmax, log_softmax)
  2. ✅ Implemented loss functions (mse_loss, cross_entropy, nll_loss, binary_cross_entropy)
  3. ✅ Completed autograd (backward(), gradient tracking)

Phase 2C: Data Operations (Week 4)

Goal: Enable real data loading and batching

  1. Implement indexing/slicing
  2. Implement cat(), stack(), split()
  3. Implement max(), min(), argmax(), argmin()

Phase 2D: Training Utilities (Week 5)

Goal: Complete training loop support

  1. Implement element-wise math (pow, sqrt, exp, log)
  2. Implement comparison operations
  3. Implement tensor creation utilities (clone, randint, etc.)

Phase 3: Advanced Features (Future)

  • Convolution and pooling layers
  • Normalization layers
  • Recurrent operations
  • Advanced linear algebra

🚨 Critical Issue: Missing Error Handling

Priority: HIGH - This is a serious bug that can crash Node.js

Problem Description

32 out of 47 C++ operations are missing try-catch blocks around PyTorch calls. This means:

  • Invalid inputs cause uncaught C++ exceptions
  • These exceptions terminate the Node.js process instead of throwing JavaScript errors
  • Users cannot catch these errors with try-catch in their code
  • The application crashes instead of handling errors gracefully

Example of the bug:

// This will CRASH the entire Node.js process:
const x = torch.tensor([1.0, 2.0, 3.0, 4.0]);
const y = x.reshape([2, 3]); // Invalid: 4 elements can't reshape to 6

// Error message:
// libc++abi: terminating due to uncaught exception of type std::runtime_error
// [Node.js process exits]

Expected behavior:

// Should throw a catchable JavaScript error:
try {
  const x = torch.tensor([1.0, 2.0, 3.0, 4.0]);
  const y = x.reshape([2, 3]);
} catch (e) {
  console.error('Caught error:', e.message);
  // Program continues...
}

Operations Missing Error Handling

Arithmetic Operations (8/8 missing):

  • add.cpp
  • add_.cpp
  • sub.cpp
  • sub_.cpp
  • mul.cpp
  • mul_.cpp
  • div.cpp
  • div_.cpp

Matrix Operations (1/1 missing):

  • matmul.cpp

Reduction Operations (2/2 missing):

  • sum.cpp
  • mean.cpp

Device Management (4/4 missing):

  • cpu.cpp
  • cuda.cpp
  • mps.cpp
  • to.cpp

Dtype Conversions (5/5 missing):

  • float.cpp
  • double.cpp
  • int.cpp
  • long.cpp
  • to.cpp (duplicate from device management)

Property Accessors (3/3 missing):

  • shape.cpp
  • dtype.cpp
  • device.cpp

Utility Methods (2/2 missing):

  • to_string.cpp
  • to_array.cpp

Shape Operations (6/6 missing):

  • reshape.cpp
  • flatten.cpp
  • unsqueeze.cpp
  • squeeze.cpp
  • transpose.cpp
  • permute.cpp

Autograd Operations (2/6 missing):

  • requires_grad.cpp
  • grad.cpp
  • backward.cpp ✅ Has try-catch
  • zero_grad.cpp ✅ Has try-catch
  • detach.cpp ✅ Has try-catch
  • no_grad.cpp ✅ Has try-catch

Activation Functions (0/5 missing):

  • relu.cpp ✅ Has try-catch
  • sigmoid.cpp ✅ Has try-catch
  • tanh.cpp ✅ Has try-catch
  • softmax.cpp ✅ Has try-catch
  • log_softmax.cpp ✅ Has try-catch

Loss Functions (0/4 missing):

  • mse_loss.cpp ✅ Has try-catch
  • cross_entropy.cpp ✅ Has try-catch
  • nll_loss.cpp ✅ Has try-catch
  • binary_cross_entropy.cpp ✅ Has try-catch

Fix Checklist

For each operation listed above:

  1. Open the .cpp file (e.g., src/native/ops/reshape.cpp)

  2. Wrap PyTorch calls in try-catch:

    // Before:
    torch::Tensor result = torch::reshape(self->tensor, new_shape);
    return Tensor::NewInstance(env, result);
    
    // After:
    try {
      torch::Tensor result = torch::reshape(self->tensor, new_shape);
      return Tensor::NewInstance(env, result);
    } catch (const std::exception& e) {
      Napi::Error::New(env, e.what()).ThrowAsJavaScriptException();
      return env.Undefined();
    }
  3. Test error handling:

    • Add test cases that intentionally cause errors
    • Verify errors are catchable with JavaScript try-catch
    • Verify error messages are helpful
  4. Build and test:

    pnpm build
    pnpm test

Implementation Strategy

Option A: Fix all at once

  • Create a single PR fixing all 32 files
  • Pros: Comprehensive fix, ensures consistency
  • Cons: Large changeset, harder to review

Option B: Fix by category

  • Fix one category at a time (e.g., all shape operations)
  • Pros: Easier to review, can prioritize high-risk operations
  • Cons: Takes longer, inconsistency during transition

Option C: Fix as we go

  • Fix operations when we add tests or modify them
  • Pros: Minimal disruption, natural progression
  • Cons: Bug persists in unfixed operations

Recommendation: Option B - Fix by category, starting with highest-risk operations:

  1. Shape operations (most likely to get invalid inputs)
  2. Arithmetic operations (commonly used)
  3. Device/dtype conversions
  4. Matrix operations
  5. Reductions and utilities
  6. Property accessors (lowest risk - rarely throw)

Testing Error Handling

For each fixed operation, add tests like:

describe('Error Handling', () => {
  it('should throw catchable error on invalid input', () => {
    expect(() => {
      const x = torch.tensor([1.0, 2.0, 3.0, 4.0]);
      x.reshape([2, 3]); // Invalid shape
    }).toThrow();
  });

  it('should allow error to be caught', () => {
    try {
      const x = torch.tensor([1.0, 2.0, 3.0, 4.0]);
      x.reshape([2, 3]);
      expect(true).toBe(false); // Should not reach here
    } catch (e) {
      expect(e.message).toContain('invalid');
    }
  });
});

Going Forward

All new operations MUST include try-catch blocks. This is now part of the standard implementation pattern.


📊 Success Metrics

Phase 2 Complete When:

  • ✅ Can define a simple feedforward neural network
  • ✅ Can perform forward pass with activations
  • ✅ Can compute loss
  • ✅ Can perform backward pass (compute gradients)
  • ✅ Can update weights (gradient descent step)
  • ✅ Can train on MNIST or similar dataset
  • ✅ Can achieve reasonable accuracy (>90% on MNIST)
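
In concrete terms, the goal is a loop like the following (a sketch assembled only from methods marked complete above; call forms and the callback-style no_grad() are assumptions):

// One gradient-descent step for a tiny linear model: pred = relu(matmul(x, W))
const x = torch.tensor([[1.0, 2.0]]); // one sample, two features
const target = torch.tensor([[1.0]]);

const W = torch.tensor([[0.1], [0.2]]);
W.requires_grad = true;

const pred = x.matmul(W).relu();    // forward pass
const loss = pred.mse_loss(target); // scalar loss
loss.backward();                    // gradients accumulate in W.grad

torch.no_grad(() => {
  W.sub_(W.grad.mul(0.01));         // SGD update, learning rate 0.01
});
W.zero_grad();                      // clear for the next iteration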

🧪 Testing Strategy

Test Organization

  • Unit Tests (test/unit/) - Device-agnostic core functionality
  • CPU Tests (test/cpu/) - CPU-specific behavior
  • MPS Tests (test/mps/) - Apple Silicon GPU tests

Test Scripts

  • pnpm test - Run all tests (unit + cpu + mps)
  • pnpm test:unit - Run unit tests only
  • pnpm test:cpu - Run CPU tests only
  • pnpm test:mps - Run MPS tests only

Coverage Goals

  • All methods should have comprehensive tests
  • Test edge cases (empty tensors, single elements, large tensors)
  • Test device transfers (CPU ↔ MPS)
  • Test gradient computation for autograd methods
  • Test numerical accuracy against PyTorch

πŸ“ Notes

Current Architecture

  • Clean modular structure: one C++ file per operation
  • Operations in src/native/ops/ directory
  • Consistent TensorOps namespace pattern
  • Comprehensive test coverage (430 tests passing across 42 test files)

Development Guidelines

  1. REQUIRED: Wrap all PyTorch C++ calls in try-catch blocks (see error handling section above)
  2. Follow established extraction pattern for new methods
  3. Add both CPU and MPS tests for each method
  4. Include error handling tests for invalid inputs
  5. Ensure all tests pass before committing
  6. Document complex operations with comments
  7. Match PyTorch's API as closely as possible

PyTorch Compatibility

  • Use PyTorch's libtorch C++ API as backend
  • Match PyTorch's method signatures where possible
  • Follow PyTorch's broadcasting rules
  • Maintain PyTorch's device model (CPU/CUDA/MPS)