An interface to Flux deep learning models for the MLJ machine learning framework
| Branch | Julia | CPU CI | GPU CI | Coverage |
| --- | --- | --- | --- | --- |
| `master` | v1 | | | |
| `dev` | v1 | | | |
```julia
using MLJ, MLJFlux, RDatasets, Plots
```
Grab some data and split into features and target:
```julia
iris = RDatasets.dataset("datasets", "iris");
y, X = unpack(iris, ==(:Species), rng=123);
X = Float32.(X); # to optimise for GPUs
```
Load model code and instantiate an MLJFlux model:
```julia
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
```
```julia
clf = NeuralNetworkClassifier(
    builder=MLJFlux.MLP(; hidden=(5,4)),
    batch_size=8,
    epochs=50,
    acceleration=CUDALibs() # for training on a GPU
)
```
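All hyperparameters of an MLJFlux model remain mutable after construction, so they can be adjusted before (re)training. A minimal sketch, with hypothetical values chosen for illustration:

```julia
# Tweak hyperparameters after construction (illustrative values):
clf.epochs = 100     # train for more epochs
clf.batch_size = 16  # use larger mini-batches
clf.lambda = 0.01    # add some regularization
```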
Wrap in "iteration controls":
```julia
stop_conditions = [
    Step(1),            # apply controls every epoch
    NumberLimit(1000),  # don't train for more than 1000 steps
    Patience(4),        # stop after 4 consecutive deteriorations in validation loss
    NumberSinceBest(5), # or if the best loss occurred 5 iterations ago
    TimeLimit(30/60),   # or if 30 minutes have passed
]
```
```julia
validation_losses = []
train_losses = []
callbacks = [
    WithLossDo(loss->push!(validation_losses, loss)),
    WithTrainingLossesDo(losses->push!(train_losses, losses[end])),
]
```
```julia
iterated_model = IteratedModel(
    model=clf,
    resampling=Holdout(fraction_train=0.5), # loss and stopping are based on out-of-sample performance
    measures=log_loss,
    controls=vcat(stop_conditions, callbacks),
);
```
Train the wrapped model:
```julia-repl
julia> mach = machine(iterated_model, X, y)
julia> fit!(mach)
[ Info: No iteration parameter specified. Using `iteration_parameter=:(epochs)`.
[ Info: final loss: 0.1284184007796247
[ Info: final training loss: 0.055630706
[ Info: Stop triggered by NumberSinceBest(5) stopping criterion.
[ Info: Total of 811 iterations.
```
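Once trained, the machine makes predictions through the standard MLJ API. A short sketch, assuming the `mach` fitted above:

```julia
# Probabilistic predictions (UnivariateFinite distributions) for a few rows:
yhat = predict(mach, rows=1:3)

# Most likely class labels instead of distributions:
predict_mode(mach, rows=1:3)
```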
Inspect results:
```julia-repl
julia> plot(train_losses, label="Training Loss")
julia> plot!(validation_losses, label="Validation Loss", linewidth=2, size=(800,400))
```
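A fitted machine can also be serialized for later use with MLJ's save/restore facility. A sketch, with a hypothetical filename:

```julia
MLJ.save("iris_nn.jls", mach)   # serialize the fitted machine to disk
mach2 = machine("iris_nn.jls")  # restore it in a fresh session
predict_mode(mach2, X)          # and predict as before
```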