diff --git a/HelloWorldApp/app/build.gradle b/HelloWorldApp/app/build.gradle index d385d700..6c0976d9 100644 --- a/HelloWorldApp/app/build.gradle +++ b/HelloWorldApp/app/build.gradle @@ -23,6 +23,6 @@ android { dependencies { implementation 'androidx.appcompat:appcompat:1.1.0' - implementation 'org.pytorch:pytorch_android:1.4.0' - implementation 'org.pytorch:pytorch_android_torchvision:1.4.0' + implementation 'org.pytorch:pytorch_android:1.6.0' + implementation 'org.pytorch:pytorch_android_torchvision:1.6.0' } diff --git a/HelloWorldApp/app/src/main/assets/model.pt b/HelloWorldApp/app/src/main/assets/model.pt index 1cea5f5c..d68d6ed4 100644 Binary files a/HelloWorldApp/app/src/main/assets/model.pt and b/HelloWorldApp/app/src/main/assets/model.pt differ diff --git a/HelloWorldApp/trace_model.py b/HelloWorldApp/trace_model.py index ccaeb013..a46798ce 100644 --- a/HelloWorldApp/trace_model.py +++ b/HelloWorldApp/trace_model.py @@ -1,8 +1,10 @@ import torch import torchvision +from torch.utils.mobile_optimizer import optimize_for_mobile model = torchvision.models.resnet18(pretrained=True) model.eval() example = torch.rand(1, 3, 224, 224) traced_script_module = torch.jit.trace(model, example) -traced_script_module.save("app/src/main/assets/model.pt") \ No newline at end of file +torchscript_model_optimized = optimize_for_mobile(traced_script_module) +torchscript_model_optimized.save("app/src/main/assets/model.pt")