-
Notifications
You must be signed in to change notification settings - Fork 229
/
Copy pathLoading.swift
94 lines (83 loc) · 3.01 KB
/
Loading.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
//
// Loading.swift
// Diffusion
//
// Created by Pedro Cuenca on December 2022.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import SwiftUI
import Combine
/// Model the app prepares at launch: `ModelInfo.xlmbpChunked` when `BENCHMARK` is set;
/// otherwise `ModelInfo.v21Palettized` if `deviceSupportsQuantization`, else `ModelInfo.v21Base`.
let model = BENCHMARK
    ? ModelInfo.xlmbpChunked
    : (deviceSupportsQuantization ? ModelInfo.v21Palettized : ModelInfo.v21Base)
/// Root view that prepares the pipeline for the global `model` and, once ready,
/// switches to `TextToImage`. If preparation throws, it shows an `ErrorPopover` instead.
struct LoadingView: View {
    /// Shared generation state. The prepared pipeline is stored here, and the
    /// object is injected into the environment for child views.
    @StateObject var generation = GenerationContext()

    /// Label displayed by the progress bar; updated from the loader's state stream.
    @State private var preparationPhase = "Downloading…"
    /// Fraction (0...1) displayed by the progress bar.
    @State private var downloadProgress: Double = 0

    /// Which screen this view is currently presenting.
    enum CurrentView {
        case loading
        case textToImage
        case error(String)
    }

    @State private var currentView: CurrentView = .loading
    /// Holds the subscription to the loader's state publisher so it stays active.
    @State private var stateSubscriber: Cancellable?

    var body: some View {
        VStack {
            switch currentView {
            case .textToImage:
                TextToImage().transition(.opacity)
            case .error(let message):
                ErrorPopover(errorMessage: message).transition(.move(edge: .top))
            case .loading:
                // TODO: Don't present progress view if the pipeline is cached
                ProgressView(preparationPhase, value: downloadProgress, total: 1).padding()
            }
        }
        .animation(.easeIn, value: currentView)
        .environmentObject(generation)
        .onAppear {
            Task {
                let loader = PipelineLoader(model: model)
                // @State must be mutated on the main thread. Let Combine deliver the
                // loader's states on the main queue instead of hopping manually
                // inside the sink. NOTE(review): the publisher presumably emits off
                // the main thread; confirm in PipelineLoader.
                stateSubscriber = loader.statePublisher
                    .receive(on: DispatchQueue.main)
                    .sink { state in
                        switch state {
                        case .downloading(let progress):
                            preparationPhase = "Downloading"
                            downloadProgress = progress
                        case .uncompressing:
                            preparationPhase = "Uncompressing"
                            downloadProgress = 1
                        case .readyOnDisk:
                            preparationPhase = "Loading"
                            downloadProgress = 1
                        default:
                            // Other loader states leave the progress UI unchanged.
                            break
                        }
                    }
                do {
                    generation.pipeline = try await loader.prepare()
                    print("Did load model \(loader.model)")
                    currentView = .textToImage
                } catch {
                    currentView = .error("Could not load model, error: \(error)")
                }
            }
        }
    }
}
// Required by .animation
// `.animation(.easeIn, value: currentView)` in `LoadingView.body` compares successive
// `currentView` values, so `CurrentView` must be Equatable. The conformance is
// compiler-synthesized, including the `error(String)` payload.
extension LoadingView.CurrentView: Equatable {}
/// Shows `errorMessage` as red headline text on a white, rounded, drop-shadowed card.
struct ErrorPopover: View {
    var errorMessage: String

    var body: some View {
        // Styled text first; the card chrome (background, corners, shadow) wraps it.
        let label = Text(errorMessage)
            .font(.headline)
            .padding()
            .foregroundColor(.red)
        return label
            .background(Color.white)
            .cornerRadius(8)
            .shadow(color: .black.opacity(0.2), radius: 8, y: 4)
    }
}
/// Xcode canvas preview that renders a fresh `LoadingView`.
struct LoadingView_Previews: PreviewProvider {
    static var previews: some View { LoadingView() }
}