Commit a2410c8

Merge pull request #604 from valhassan/601-maintenance-manage-geo-deep-learning-with-uv-by-astral
601 maintenance manage geo deep learning with uv by astral
2 parents 1e397c9 + 92ebd31 commit a2410c8

11 files changed: +149 -69 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,9 @@
 *.idea**
 *.vscode**
 
+# Distribution / packaging
+*.egg-info/
+
 # Specific folders name
 waterloo_subset_512/
 mlruns/

README.md

Lines changed: 70 additions & 46 deletions
@@ -4,17 +4,18 @@ A PyTorch Lightning-based framework for geospatial deep learning with multi-sens
 
 ## Overview
 
-Geo Deep Learning (GDL) is a modular framework designed for semantic segmentation of geospatial imagery using state-of-the-art deep learning models. Built on PyTorch Lightning, it provides efficient training pipelines for multi-sensor data with WebDataset support.
+Geo Deep Learning (GDL) is a modular framework designed to support a wide range of geospatial deep learning tasks such as semantic segmentation, object detection, and regression.
+Built on PyTorch Lightning, it provides efficient training pipelines for multi-sensor data.
 
 ## Features
 
-- **Multi-sensor Support**: Handle multiple Earth observation sensors simultaneously
-- **Modular Architecture**: Encoder-neck-decoder pattern with interchangeable components
-- **WebDataset Integration**: Efficient large-scale data loading and processing
-- **Multiple Model Types**: UNet++, SegFormer, DOFA (Dynamic-one-for-all Architecture)
-- **Distributed Training**: Multi-GPU training with DDP strategy
-- **MLflow Logging**: Comprehensive experiment tracking and model versioning
-- **Flexible Data Pipeline**: Support for CSV and WebDataset formats
+- **Multi-sensor Support**: Handle multiple Earth observation sensors simultaneously.
+- **Modular Architecture**: Encoder-neck-decoder pattern with interchangeable components.
+- **WebDataset Integration**: Efficient large-scale data loading and processing.
+- **Multiple Model Types**: UNet++, SegFormer, DOFA (Dynamic-one-for-all Architecture).
+- **Distributed Training**: Multi-GPU training with supported strategies.
+- **MLflow Logging**: Comprehensive experiment tracking and model versioning.
+- **Flexible Data Pipeline**: Support for CSV and WebDataset formats.
 
 ## Architecture
 
@@ -31,23 +32,47 @@ Geo Deep Learning (GDL) is a modular framework designed for semantic segmentatio
 └── samplers/ # Custom data sampling strategies
 ```
 
+## Requirements
+- Install [uv](https://docs.astral.sh/uv/) package manager for your OS.
+
 ## Quick Start
 
+1. **Clone the repository:**
 ```bash
-git clone <repository-url>
+git clone https://github.com/NRCan/geo-deep-learning.git
 cd geo-deep-learning
 ```
+2. **Install dependencies:**
 
-### Training
+For **GPU training** with CUDA 12.8:
+```bash
+uv sync --extra cu128
+```
 
+For **CPU-only** training:
 ```bash
-# Single GPU training
-python geo_deep_learning/train.py fit --config configs/dofa_config_RGB.yaml
+uv sync --extra cpu
 ```
+This creates a virtual environment in `.venv/` and installs all dependencies.
+
+3. **Activate the environment:**
+```bash
+# Linux/macOS
+source .venv/bin/activate
+
+# Windows
+.venv\Scripts\activate
+```
+
+Or use `uv run` to execute commands without manual activation:
+```bash
+uv run python geo_deep_learning/train.py fit --config configs/dofa_config_RGB.yaml
+```
+**Note:** *If you prefer to use conda or another environment manager, you can generate a `requirements.txt` file from the dependencies listed in `pyproject.toml` for manual installation.*
 
 ### Configuration
 
-Models are configured via YAML files in `configs/`:
+Models are configured via YAML files in the `configs/` directory:
 
 ```yaml
 model:
@@ -65,54 +90,53 @@ data:
   sensor_configs_path: "path/to/sensor_configs.yaml"
   batch_size: 16
   patch_size: [512, 512]
+
+trainer:
+  max_epochs: 100
+  precision: 16-mixed
+  accelerator: gpu
+  devices: 1
 ```
 
 ## Supported Models
 
-### DOFA (Domain-Oriented Foundation Architecture)
-- **DOFA Base**: 768-dim embeddings, suitable for most tasks
-- **DOFA Large**: 1024-dim embeddings, higher capacity
-- Multi-scale feature extraction with UperNet decoder
-- Support for wavelength-specific processing
-
 ### UNet++
-- Classic U-Net architecture with dense skip connections
-- Multiple encoder backbones (ResNet, EfficientNet, etc.)
-- Optimized for medical and satellite imagery
+- Classic U-Net architecture with dense skip connections.
+- Multiple encoder backbones (ResNet, EfficientNet, etc.).
+- Available through segmentation-models-pytorch.
 
 ### SegFormer
-- Transformer-based architecture for semantic segmentation
-- Hierarchical feature representation
-- Efficient attention mechanisms
+- Transformer-based architecture for semantic segmentation.
+- Hierarchical feature representation (MixTransformer encoder).
+- Multiple model sizes (B0-B5).
+
+### DOFA (Dynamic One-For-All foundation model)
+- **DOFA Base**: 768-dim embeddings, suitable for most tasks.
+- **DOFA Large**: 1024-dim embeddings, higher capacity.
+- Multi-scale feature extraction with UperNet decoder.
+- Support for wavelength-specific processing.
+
 
 ## Data Pipeline
 
 ### Multi-Sensor DataModule
-- **Sensor Mixing**: Combine data from multiple sensors during training
-- **WebDataset Format**: Efficient sharded data storage and loading
-- **Patch-based Processing**: Configurable patch sizes (default: 512x512)
-- **Data Augmentation**: Built-in augmentation pipeline
+- **Sensor Mixing**: Combine data from multiple sensors during training.
+- **WebDataset Format**: Efficient sharded data storage and loading.
 
 ### Supported Data Formats
-- **WebDataset**: Sharded tar files with metadata
-- **CSV**: Traditional CSV with file paths and labels
-- **Multi-sensor**: YAML configuration for sensor-specific settings
+- **WebDataset**: Sharded tar files with metadata.
+- **CSV**: Traditional CSV with file paths and labels.
+- **Multi-sensor**: YAML configuration for sensor-specific settings.
 
 ## Training Features
-
-- **Mixed Precision**: 16-bit mixed precision training
-- **Gradient Clipping**: Configurable gradient clipping
-- **Early Stopping**: Automatic training termination
-- **Model Checkpointing**: Best model saving based on validation metrics
-- **Visualization**: Built-in prediction visualization callbacks
-
-## Distributed Training
-
-The framework supports multi-GPU training with:
-- DDP (Distributed Data Parallel) strategy
-- Automatic mixed precision
-- Synchronized batch normalization
-- Efficient NCCL communication
+- **Large-scale training**: Distributed training strategies enabled with pytorch lightning.
+- **Mixed Precision Training**: 16-bit mixed precision for faster training.
+- **Gradient Clipping**: Configurable gradient clipping for stability.
+- **Early Stopping**: Automatic training termination based on validation metrics.
+- **Model Checkpointing**: Saves best models based on validation performance.
+- **MLflow Integration**: Experiment tracking, metrics logging, and model registry.
+- **Visualization Callbacks**: Built-in prediction visualization during training.
+- **Learning Rate Scheduling**: Cosine annealing, step decay, and more.
 
 ## Development
 
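
Review note on the README's new `trainer:` block: with `LightningCLI`, keys under `trainer:` are passed to `lightning.pytorch.Trainer` as keyword arguments, and the same keys can be overridden on the command line as `--trainer.<key>=<value>`. The sketch below is only an illustration of that mapping, using the four values from the README example. The helper names `to_cli_overrides` and `build_trainer` are hypothetical and do not exist in the repository.

```python
# Values copied from the `trainer:` block in the README example.
trainer_kwargs = {
    "max_epochs": 100,
    "precision": "16-mixed",
    "accelerator": "gpu",
    "devices": 1,
}


def to_cli_overrides(kwargs: dict) -> list[str]:
    """Render trainer settings as LightningCLI-style command-line overrides.

    Hypothetical helper, shown only to illustrate the key-to-flag mapping.
    """
    return [f"--trainer.{key}={value}" for key, value in kwargs.items()]


def build_trainer(kwargs: dict):
    """Build a Trainer directly from the same settings.

    Lightning is imported lazily so the rest of this sketch runs without it.
    Calling this with accelerator="gpu" assumes a CUDA device is available.
    """
    from lightning.pytorch import Trainer

    return Trainer(**kwargs)


overrides = to_cli_overrides(trainer_kwargs)
print(" ".join(overrides))
```

The printed flags could be appended to the `uv run python geo_deep_learning/train.py fit --config ...` command from the README to override the YAML values for a single run.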
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+"""Logging configuration."""
File renamed without changes.
File renamed without changes.

geo_deep_learning/models/encoders/mix_transformer.py

Lines changed: 2 additions & 1 deletion
@@ -7,11 +7,12 @@
 
 import torch
 import torch.nn.functional as fn
-from models.segmentation.base import EncoderMixin
 from timm.layers import DropPath, to_2tuple, trunc_normal_
 from torch import Tensor, nn
 from torch.utils import model_zoo
 
+from geo_deep_learning.models.segmentation.base import EncoderMixin
+
 
 class Mlp(nn.Module):
     """MLP module."""

geo_deep_learning/models/segmentation/segformer.py

Lines changed: 6 additions & 2 deletions
@@ -2,8 +2,12 @@
 
 import torch
 import torch.nn.functional as fn
-from models.decoders.segformer_mlp import Decoder
-from models.encoders.mix_transformer import DynamicMixTransformer, get_encoder
+
+from geo_deep_learning.models.decoders.segformer_mlp import Decoder
+from geo_deep_learning.models.encoders.mix_transformer import (
+    DynamicMixTransformer,
+    get_encoder,
+)
 
 from .base import BaseSegmentationModel
 
geo_deep_learning/tasks_with_models/segmentation_dofa.py

Lines changed: 2 additions & 2 deletions
@@ -16,10 +16,10 @@
 from torchmetrics.segmentation import MeanIoU
 from torchmetrics.wrappers import ClasswiseWrapper
 
+from geo_deep_learning.models.segmentation.dofa import DOFASegmentationModel
+from geo_deep_learning.tools.visualization import visualize_prediction
 from geo_deep_learning.utils.models import load_weights_from_checkpoint
 from geo_deep_learning.utils.tensors import denormalization
-from models.segmentation.dofa import DOFASegmentationModel
-from tools.visualization import visualize_prediction
 
 # Ignore warning about default grid_sample and affine_grid behavior triggered by kornia
 warnings.filterwarnings(

geo_deep_learning/tasks_with_models/segmentation_segformer.py

Lines changed: 2 additions & 2 deletions
@@ -16,10 +16,10 @@
 from torchmetrics.segmentation import MeanIoU
 from torchmetrics.wrappers import ClasswiseWrapper
 
+from geo_deep_learning.models.segmentation.segformer import SegFormerSegmentationModel
+from geo_deep_learning.tools.visualization import visualize_prediction
 from geo_deep_learning.utils.models import load_weights_from_checkpoint
 from geo_deep_learning.utils.tensors import denormalization
-from models.segmentation.segformer import SegFormerSegmentationModel
-from tools.visualization import visualize_prediction
 
 warnings.filterwarnings(
     "ignore",

geo_deep_learning/train.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 from lightning.pytorch.cli import ArgsType, LightningCLI
 from lightning.pytorch.loggers import MLFlowLogger
 
-from configs import logging_config  # noqa: F401
+from geo_deep_learning.config import logging_config  # noqa: F401
 from geo_deep_learning.tools.mlflow_logger import LoggerSaveConfigCallback
 
 logger = logging.getLogger(__name__)
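
Review note on the import changes across these files: imports such as `from models...` and `from tools...` only resolve when the `geo_deep_learning/` directory itself is on `sys.path` (for example, because a script inside it is run directly), whereas `from geo_deep_learning...` resolves whenever the package's parent directory or an installed copy is importable. The sketch below reproduces that difference with a throwaway package. The names `gdl_demo_pkg` and `gdl_demo_models` are hypothetical stand-ins, not repository modules.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def make_demo_package(root: Path) -> None:
    """Create gdl_demo_pkg/gdl_demo_models/base.py under root (hypothetical layout)."""
    models_dir = root / "gdl_demo_pkg" / "gdl_demo_models"
    models_dir.mkdir(parents=True)
    (root / "gdl_demo_pkg" / "__init__.py").write_text("")
    (models_dir / "__init__.py").write_text("")
    (models_dir / "base.py").write_text("class EncoderMixin:\n    pass\n")


def can_import(name: str) -> bool:
    """Return True if the module can be imported with the current sys.path."""
    importlib.invalidate_caches()
    try:
        importlib.import_module(name)
    except ModuleNotFoundError:
        return False
    return True


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    make_demo_package(root)
    # Only the parent of the package is importable, as with an installed package.
    sys.path.insert(0, str(root))
    try:
        absolute_ok = can_import("gdl_demo_pkg.gdl_demo_models.base")
        top_level_ok = can_import("gdl_demo_models.base")
    finally:
        sys.path.remove(str(root))

print(absolute_ok, top_level_ok)  # True False
```

The package-qualified import succeeds while the top-level one fails, which is consistent with the move to `geo_deep_learning.`-prefixed imports once the project is installed into the environment rather than run from inside its source directory.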
