From fcf166d8acf1f09dfd083ed72475e120be02bb78 Mon Sep 17 00:00:00 2001 From: Robin Tuszik <47579899+rtuszik@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:35:14 +0200 Subject: [PATCH] refactor: restructure project and improve code quality - Move main GUI logic to src/gui.py - Create separate files for image generation and utilities - Improve error handling and type annotations - Enhance code readability and maintainability - Add Pillow dependency for image metadata handling --- requirements.txt | 1 + FluxLoraGUI.py => src/gui.py | 277 +++++++++++++---------------------- src/image_generator.py | 18 +++ src/main.py | 16 ++ src/utils.py | 74 ++++++++++ 5 files changed, 213 insertions(+), 173 deletions(-) rename FluxLoraGUI.py => src/gui.py (78%) create mode 100644 src/image_generator.py create mode 100644 src/main.py create mode 100644 src/utils.py diff --git a/requirements.txt b/requirements.txt index 24e7377..e77f435 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ PyQt6==6.7.1 replicate==0.31.0 python-dotenv==1.0.1 token_count==0.2.1 +pillow==10.4.0 diff --git a/FluxLoraGUI.py b/src/gui.py similarity index 78% rename from FluxLoraGUI.py rename to src/gui.py index 2bcead6..6796ab6 100644 --- a/FluxLoraGUI.py +++ b/src/gui.py @@ -1,111 +1,35 @@ import os import time from urllib.request import urlretrieve + +from PyQt6.QtCore import QSettings, Qt, QThreadPool, QTimer +from PyQt6.QtGui import QGuiApplication, QPixmap, QResizeEvent from PyQt6.QtWidgets import ( - QApplication, - QMainWindow, - QWidget, - QVBoxLayout, + QCheckBox, + QComboBox, + QDialog, + QDoubleSpinBox, + QFileDialog, + QFormLayout, + QGridLayout, QHBoxLayout, + QLabel, QLineEdit, + QMainWindow, + QMessageBox, + QProgressBar, QPushButton, - QLabel, QScrollArea, - QGridLayout, - QMessageBox, - QComboBox, - QSpinBox, - QDoubleSpinBox, - QCheckBox, - QFormLayout, QSizePolicy, - QTextEdit, - QProgressBar, - QFileDialog, - QDialog, + QSpinBox, QStatusBar, + QTextEdit, + QVBoxLayout, + QWidget, ) -from PyQt6.QtGui import QPixmap, QGuiApplication, QResizeEvent -from PyQt6.QtCore import ( - Qt, - QThread, - pyqtSignal, - QSettings, - QTimer, - QRunnable, - QThreadPool, - QObject, -) -import replicate -from dotenv import load_dotenv -from token_count import TokenCount -load_dotenv() +from utils import ImageGeneratorThread, ImageLoader, TokenCounter -class ImageLoaderSignals(QObject): - finished = pyqtSignal(list) - -class ImageLoader(QRunnable): - def __init__(self, folder_path): - super().__init__() - self.folder_path = folder_path - self.signals = ImageLoaderSignals() - - def run(self): - images = [] - for filename in os.listdir(self.folder_path): - if filename.lower().endswith((".png", ".jpg", ".jpeg", ".webp")): - file_path = os.path.join(self.folder_path, filename) - mod_time = os.path.getmtime(file_path) - images.append((file_path, mod_time)) - images.sort(key=lambda x: x[1], reverse=True) - self.signals.finished.emit([img[0] for img in images]) - -class ImageGeneratorThread(QThread): - finished = pyqtSignal(list) - error = pyqtSignal(str) - - def __init__(self, params): - super().__init__() - self.params = params - - def run(self): - try: - output = replicate.run( - "rtuszik/fluxlyptus:4e304b52ad6745623fb29f3250d89df23ac38b42734887d9e0a4b3a31c648472", - input=self.params, - ) - self.finished.emit(output) - except Exception as e: - self.error.emit(str(e)) - -class TokenCounter(QWidget): - def __init__(self, text_edit, *args, **kwargs): - super().__init__(*args, **kwargs) - self.text_edit = text_edit - self.tc = TokenCount(model_name="gpt-3.5-turbo") - - layout = QVBoxLayout(self) - self.token_count_label = QLabel("Tokens: 0") - self.warning_label = QLabel() - self.warning_label.setStyleSheet("color: orange;") - self.warning_label.hide() - - layout.addWidget(self.token_count_label) - layout.addWidget(self.warning_label) - - self.text_edit.textChanged.connect(self.update_count) - - def update_count(self): - text = self.text_edit.toPlainText() - token_count = self.tc.num_tokens_from_string(text) - self.token_count_label.setText(f"Tokens: {token_count}") - - if token_count > 77: - self.warning_label.setText("Warning: Tokens beyond 77 will be ignored") - self.warning_label.show() - else: - self.warning_label.hide() class ImageViewer(QDialog): def __init__(self, pixmap, parent=None): @@ -131,22 +55,20 @@ def initUI(self): self.updateImage() def updateImage(self): - if self.save_button: - button_height = self.save_button.height() - else: - button_height = 0 - - scaled_pixmap = self.original_pixmap.scaled( - self.width(), - self.height() - button_height, - Qt.AspectRatioMode.KeepAspectRatio, - Qt.TransformationMode.SmoothTransformation, - ) - self.image_label.setPixmap(scaled_pixmap) + if self.image_label: + button_height = self.save_button.height() if self.save_button else 0 + scaled_pixmap = self.original_pixmap.scaled( + self.width(), + self.height() - button_height, + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation, + ) + self.image_label.setPixmap(scaled_pixmap) - def resizeEvent(self, event: QResizeEvent): - self.updateImage() - super().resizeEvent(event) + def resizeEvent(self, a0: QResizeEvent | None) -> None: + if a0 is not None: + self.updateImage() + super().resizeEvent(a0) def saveImage(self): file_name, _ = QFileDialog.getSaveFileName( @@ -155,6 +77,7 @@ def saveImage(self): if file_name: self.original_pixmap.save(file_name) + class ImagePreviewWidget(QLabel): def __init__(self, pixmap, file_path, parent=None): super().__init__(parent) @@ -171,25 +94,27 @@ def __init__(self, pixmap, file_path, parent=None): self.setAlignment(Qt.AlignmentFlag.AlignCenter) self.setStyleSheet(""" QLabel { - border: 2px solid + border: 2px solid #555555; border-radius: 10px; padding: 5px; margin: 5px; } QLabel:hover { - border-color: + border-color: #0078d7; } """) self.setMinimumSize(310, 310) - def mousePressEvent(self, event): - if event.button() == Qt.MouseButton.LeftButton: + def mousePressEvent(self, ev): + if ev.button() == Qt.MouseButton.LeftButton: viewer = ImageViewer(self.original_pixmap, self.parent()) viewer.exec() + class ImageGeneratorGUI(QMainWindow): - def __init__(self): + def __init__(self, image_generator): super().__init__() + self.image_generator = image_generator self.settings = QSettings("rtuszik", "Flux-Dev-Lora-GUI") self.threadpool = QThreadPool() self.current_thread = None @@ -217,64 +142,64 @@ def initUI(self): def getStyleSheet(self): return """ QMainWindow, QWidget { - background-color: - color: + background-color: #2b2b2b; + color: #f0f0f0; font-family: 'Arial', 'Sans-Serif'; font-size: 13px; } QLineEdit, QTextEdit, QComboBox, QSpinBox, QDoubleSpinBox { - background-color: - border: 1px solid - borderradius: 5px; + background-color: #3c3c3c; + border: 1px solid #555555; + border-radius: 5px; padding: 5px; - color: + color: #f0f0f0; width: 100%; } QPushButton { - background-color: + background-color: #5c5c5c; color: white; border: none; - borderradius: 5px; + border-radius: 5px; padding: 8px 16px; font-weight: 500; min-height: 30px; width: 100%; } QPushButton:hover { - background-color: + background-color: #6c6c6c; } QPushButton:pressed { - background-color: - border: 1px solid + background-color: #4c4c4c; + border: 1px solid #333333; } QLabel { - color: + color: #f0f0f0; } QScrollArea { border: none; - background-color: + background-color: #3c3c3c; } QCheckBox { - spacin: 5px; - color: + spacing: 5px; + color: #f0f0f0; } QCheckBox::indicator { width: 18px; height: 18px; } QCheckBox::indicator:unchecked { - border 2px solid - background-color: + border: 2px solid #888888; + background-color: #3c3c3c; } QCheckBox::indicator:checked { - border: 2px solid - background-color: + border: 2px solid #0078d7; + background-color: #0078d7; } """ def setupMainWidget(self): main_widget = QWidget() - main_layout = QVBoxLayout(main_widget) # Corrected line + main_layout = QVBoxLayout(main_widget) main_layout.setContentsMargins(20, 20, 20, 20) main_layout.setSpacing(20) self.setCentralWidget(main_widget) @@ -406,8 +331,6 @@ def setupBottomPanel(self): self.interrupt_button = QPushButton("Interrupt Generation") self.interrupt_button.clicked.connect(self.interrupt_generation) self.interrupt_button.setEnabled(False) - self.bottom_layout - self.bottom_layout.addWidget(self.interrupt_button) def setupStatusBar(self): @@ -450,8 +373,10 @@ def toggle_view(self): def updateGallery(self, image_paths=None): if image_paths is None: image_paths = [ - self.gallery_layout.itemAt(i).widget().file_path + item.widget().file_path for i in range(self.gallery_layout.count()) + if (item := self.gallery_layout.itemAt(i)) + and isinstance(item.widget(), ImagePreviewWidget) ] sorted_images = sorted( @@ -459,36 +384,39 @@ def updateGallery(self, image_paths=None): ) existing_images = { - self.gallery_layout.itemAt(i).widget().file_path + item.widget().file_path for i in range(self.gallery_layout.count()) + if (item := self.gallery_layout.itemAt(i)) + and isinstance(item.widget(), ImagePreviewWidget) } - images_to_add = [path for path in sorted_images if path not in existing_images] - - for path in images_to_add: - pixmap = QPixmap(path) - preview = ImagePreviewWidget(pixmap, path) - if self.is_grid_view: - row = self.gallery_layout.count() // 3 - col = self.gallery_layout.count() % 3 - else: - row = self.gallery_layout.count() - col = 0 - self.gallery_layout.addWidget(preview, row, col) + for path in sorted_images: + if path not in existing_images: + pixmap = QPixmap(path) + preview = ImagePreviewWidget(pixmap, path) + if self.is_grid_view: + row = self.gallery_layout.count() // 3 + col = self.gallery_layout.count() % 3 + else: + row = self.gallery_layout.count() + col = 0 + self.gallery_layout.addWidget(preview, row, col) for i in range(self.gallery_layout.count()): - self.gallery_layout.itemAt(i).widget().show() + item = self.gallery_layout.itemAt(i) + if item and isinstance(item.widget(), ImagePreviewWidget): + item.widget().show() - self.gallery_scroll.verticalScrollBar().setValue( - self.gallery_scroll.verticalScrollBar().minimum() - ) + scrollbar = self.gallery_scroll.verticalScrollBar() + if scrollbar: + scrollbar.setValue(scrollbar.minimum()) def clearGallery(self): for i in reversed(range(self.gallery_layout.count())): - widget = self.gallery_layout.itemAt(i).widget() - if widget is not None: - widget.hide() - self.gallery_layout.removeWidget(widget) + item = self.gallery_layout.itemAt(i) + if item and isinstance(item.widget(), ImagePreviewWidget): + item.widget().hide() + self.gallery_layout.removeWidget(item.widget()) def center(self): primary_screen = QGuiApplication.primaryScreen() @@ -541,6 +469,9 @@ def display_images(self, image_urls): def add_metadata_to_image(self, image_path, prompt): try: + from PIL import Image + from PIL.PngImagePlugin import PngInfo + with Image.open(image_path) as img: if img.format == "PNG": metadata = PngInfo() @@ -548,7 +479,7 @@ def add_metadata_to_image(self, image_path, prompt): img.save(image_path, pnginfo=metadata) elif img.format in ["JPEG", "WEBP"]: exif = img.getexif() - exif[0x9286] = prompt + exif[0x9286] = prompt # 0x9286 is the UserComment EXIF tag img.save(image_path, exif=exif) except Exception as e: print(f"Error adding metadata to {image_path}: {str(e)}") @@ -585,9 +516,11 @@ def loadSettings(self): "save_directory", os.path.expanduser("~/Downloads/replicate") ) ) - self.save_metadata_checkbox.setChecked( - self.settings.value("save_metadata", False, type=bool) - ) + if self.save_metadata_checkbox: + self.save_metadata_checkbox.setChecked( + self.settings.value("save_metadata", False, type=bool) + ) + self.loadImagesAsync() def saveSettings(self): @@ -649,7 +582,7 @@ def generate_images(self): self.progress_bar.show() self.generate_button.setEnabled(False) - self.current_thread = ImageGeneratorThread(params) + self.current_thread = ImageGeneratorThread(self.image_generator, params) self.current_thread.finished.connect(self.display_images) self.current_thread.error.connect(self.show_error) self.current_thread.start() @@ -671,14 +604,12 @@ def show_error(self, error_message): def clear_images(self): for i in reversed(range(self.gallery_layout.count())): - self.gallery_layout.itemAt(i).widget().setParent(None) + item = self.gallery_layout.itemAt(i) + if item and isinstance(item.widget(), ImagePreviewWidget): + widget = item.widget() + widget.setParent(None) - def closeEvent(self, event): + def closeEvent(self, a0): self.saveSettings() - super().closeEvent(event) - -if __name__ == "__main__": - app = QApplication([]) - ex = ImageGeneratorGUI() - ex.show() - app.exec() + super().closeEvent(a0) + self.interrupt_button.setEnabled(False) diff --git a/src/image_generator.py b/src/image_generator.py new file mode 100644 index 0000000..6e4a73d --- /dev/null +++ b/src/image_generator.py @@ -0,0 +1,18 @@ +import replicate +from dotenv import load_dotenv + +load_dotenv() + +class ImageGenerator: + def __init__(self): + self.model = "rtuszik/fluxlyptus:4e304b52ad6745623fb29f3250d89df23ac38b42734887d9e0a4b3a31c648472" + + def generate_images(self, params): + try: + output = replicate.run(self.model, input=params) + return output + except Exception as e: + raise ImageGenerationError(f"Error generating images: {str(e)}") + +class ImageGenerationError(Exception): + pass \ No newline at end of file diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..740d7ce --- /dev/null +++ b/src/main.py @@ -0,0 +1,16 @@ +from PyQt6.QtWidgets import QApplication + +from gui import ImageGeneratorGUI +from image_generator import ImageGenerator + + +def main(): + app = QApplication([]) + generator = ImageGenerator() + window = ImageGeneratorGUI(generator) + window.show() + app.exec() + + +if __name__ == "__main__": + main() diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..28e37ea --- /dev/null +++ b/src/utils.py @@ -0,0 +1,74 @@ +import os +from PyQt6.QtCore import QRunnable, QThread, pyqtSignal, QObject +from PyQt6.QtWidgets import QWidget, QVBoxLayout, QLabel + +class ImageLoaderSignals(QObject): + finished = pyqtSignal(list) + +class ImageLoader(QRunnable): + def __init__(self, folder_path): + super().__init__() + self.folder_path = folder_path + self.signals = ImageLoaderSignals() + + def run(self): + images = [] + for filename in os.listdir(self.folder_path): + if filename.lower().endswith((".png", ".jpg", ".jpeg", ".webp")): + file_path = os.path.join(self.folder_path, filename) + mod_time = os.path.getmtime(file_path) + images.append((file_path, mod_time)) + images.sort(key=lambda x: x[1], reverse=True) + self.signals.finished.emit([img[0] for img in images]) + +class ImageGeneratorThread(QThread): + finished = pyqtSignal(list) + error = pyqtSignal(str) + + def __init__(self, image_generator, params): + super().__init__() + self.image_generator = image_generator + self.params = params + + def run(self): + try: + output = self.image_generator.generate_images(self.params) + self.finished.emit(output) + except Exception as e: + self.error.emit(str(e)) + +class TokenCount: + def __init__(self, model_name): + self.model_name = model_name + + def num_tokens_from_string(self, string: str) -> int: + # This is a simplified implementation. You might want to use a proper tokenizer here. + return len(string.split()) + +class TokenCounter(QWidget): + def __init__(self, text_edit, *args, **kwargs): + super().__init__(*args, **kwargs) + self.text_edit = text_edit + self.tc = TokenCount(model_name="gpt-3.5-turbo") + + layout = QVBoxLayout(self) + self.token_count_label = QLabel("Tokens: 0") + self.warning_label = QLabel() + self.warning_label.setStyleSheet("color: orange;") + self.warning_label.hide() + + layout.addWidget(self.token_count_label) + layout.addWidget(self.warning_label) + + self.text_edit.textChanged.connect(self.update_count) + + def update_count(self): + text = self.text_edit.toPlainText() + token_count = self.tc.num_tokens_from_string(text) + self.token_count_label.setText(f"Tokens: {token_count}") + + if token_count > 77: + self.warning_label.setText("Warning: Tokens beyond 77 will be ignored") + self.warning_label.show() + else: + self.warning_label.hide()