Repository: bnsreenu/digitalsreeni-image-annotator Branch: master Commit: 9f3f112ee393 Files: 34 Total size: 467.9 KB Directory structure: gitextract_jmgtmxdc/ ├── LICENSE ├── README.md ├── Release Notes 0.8.12.md ├── data/ │ ├── YOLO11n-model-yaml/ │ │ ├── coco8.yaml │ │ └── download_YOLO_models.txt │ └── download_SAM_models.txt ├── requirements.txt ├── setup.py └── src/ └── digitalsreeni_image_annotator/ ├── __init__.py ├── annotation_statistics.py ├── annotation_utils.py ├── annotator_window.py ├── coco_json_combiner.py ├── constants.py ├── dataset_splitter.py ├── default_stylesheet.py ├── dicom_converter.py ├── export_formats.py ├── help_window.py ├── image_augmenter.py ├── image_label.py ├── image_patcher.py ├── import_formats.py ├── main.py ├── project_details.py ├── project_search.py ├── sam_utils.py ├── slice_registration.py ├── snake_game.py ├── soft_dark_stylesheet.py ├── stack_interpolator.py ├── stack_to_slices.py ├── utils.py └── yolo_trainer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 Dr. Sreenivas Bhattiprolu Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- CITATION REQUEST: If you use this software in your research, please cite it as follows: Bhattiprolu, S. (2024). Image Annotator [Computer software]. https://github.com/bnsreenu/digitalsreeni-image-annotator BibTeX: @software{image_annotator, author = {Bhattiprolu, Sreenivas}, title = {Image Annotator}, year = {2024}, url = {https://github.com/bnsreenu/digitalsreeni-image-annotator} } While not required by the license, citation is appreciated and helps support the continued development and maintenance of this software. ================================================ FILE: README.md ================================================ # DigitalSreeni Image Annotator and Toolkit    A powerful and user-friendly tool for annotating images with polygons and rectangles, built with PyQt5. Now with additional supporting tools for comprehensive image processing and dataset management. ## Support the Project If you find this project helpful, consider supporting it: [](https://www.paypal.com/donate/?business=FGQL3CNJGJP9C&no_recurring=0&item_name=If+you+find+this+Image+Annotator+project+helpful%2C+consider+supporting+it%3A&currency_code=USD)  ## Watch the demo (of v0.8.0): [](https://youtu.be/aArn1f1YIQk) @DigitalSreeni Dr. Sreenivas Bhattiprolu ## Features - Semi-automated annotations with SAM-2 assistance (Segment Anything Model) — Because who doesn't love a helpful AI sidekick? 
- Manual annotations with polygons and rectangles — For when you want to show SAM-2 who's really in charge. - Paint brush and Eraser tools with adjustable pen sizes (use - and = on your keyboard) - Merge annotations - For when SAM-2's guesswork needs a little human touch. - Save and load projects for continued work. - Save As... and Autosave functionality. - A secret game, for when you are bored. - Import existing COCO JSON annotations with images. - Export annotations to various formats (COCO JSON, YOLO v8/v11, Labeled images, Semantic labels, Pascal VOC). - Handle multi-dimensional images (TIFF stacks and CZI files). - Zoom and pan for detailed annotations. - Support for multiple classes with customizable colors. - User-friendly interface with intuitive controls. - Change the application font size on the fly — Make your annotations as big or small as your caffeine level requires. - Dark mode for those late-night annotation marathons — Who needs sleep when you have dark mode? - Pick appropriate pre-trained SAM2 model for flexible and improved semi-automated annotations. - Change the class of an annotation to a different class. - Turn visibility of a class ON and OFF. - YOLO (beta) training using current annotations and loading trained model to segment images. - Area measurements for annotations displayed next to the Annotation name. - Sort annotations by name/number or area. - Additional supporting tools: - Annotation statistics for current annotations - COCO JSON combiner - Dataset splitter - Stack to slices converter - Image patcher - Image augmenter - Project Details: View and edit project metadata, including creation date, last modified date, image information, and custom notes. - Advanced Project Search: Search through multiple projects using complex queries with logical operators (AND, OR) and parentheses. 
- Slice Registration - Align image slices in a stack with multiple registration methods - Support for various reference frames and transformation types - Stack Interpolation - Adjust Z-spacing in image stacks - Multiple interpolation methods with memory-efficient processing - DICOM Converter - Convert DICOM files to TIFF format (single stack or individual slices) - Preserve metadata and physical dimensions - Export metadata to JSON for reference ## Operating System Requirements This application is built using PyQt5 and has been tested on macOS and Windows. It may experience compatibility issues on Linux systems, particularly related to the XCB plugin for PyQt5. Extensive testing on Linux systems has not been done yet. ## Installation ### Watch the installation walkthrough video: [](https://youtu.be/VI6V95eUUpY) You can install the DigitalSreeni Image Annotator directly from PyPI: ```bash pip install digitalsreeni-image-annotator ``` The application uses the Ultralytics library, so there's no need to separately install SAM2 or PyTorch, or download SAM2 models manually. ## Usage 1. Run the DigitalSreeni Image Annotator application: ```bash digitalsreeni-image-annotator ``` or ```bash sreeni ``` or ```bash python -m digitalsreeni_image_annotator.main ``` 2. Using the application: - Click "New Project" or use Ctrl+N to start a new project. - Use "Add New Images" to import images, including TIFF stacks and CZI files. - Add classes using the "Add Classes" button. - Select a class and use the Polygon or Rectangle or Paint Brush tool to create manual annotations. - To use SAM2-assisted annotation: - Select a model from the "Pick a SAM Model" dropdown. It's recommended to use smaller models like SAM2 tiny or SAM2 small. SAM2 large is not recommended as it may crash the application on systems with limited resources. - Note: When you select a model for the first time, the application needs to download it. 
This process may take a few seconds to a minute, depending on your internet connection speed. Subsequent uses of the same model will be faster as it will already be cached locally, in your working directory. - Click the "SAM-Assisted" button to activate the tool. - Draw a rectangle around objects of interest to allow SAM2 to automatically detect objects. - Note that SAM2 provides various outputs with different scores, and only the top-scoring region will be displayed. If the desired result isn't achieved on the first try, draw again. - For low-quality images where SAM2 may not auto-detect objects, manual tools may be necessary. - When SAM2 detects only part of an object, use polygon or paint brush tools to manually define the remaining region and use the Merge tool to combine both annotations into one. - When SAM2 over-annotates objects, extending the annotation beyond the object's boundaries, use the Eraser tool to clean up the edges. - Both paint brush and eraser tools can be adjusted for pen size by using - or = keys on your keyboard. - Edit existing annotations by double-clicking on them. - Edit existing annotations using the Eraser tool. Adjust the eraser size by using - or = keys on your keyboard. - Merge connected annotations by selecting them from the Annotations list and clicking the Merge button. - Change the class of an annotation to a different class. - Turn visibility of a class ON and OFF. - Use YOLO (beta) training with current annotations and load the trained model to segment images and convert segmentations to annotations. (Currently not implemented for slices or stacks, just single images.) - Accept/reject one or select class predictions at a time to add them as annotations. - View area measurements for annotations displayed next to the Annotation name. - Sort annotations by name/number or area. - Save your project using "Save Project" or Ctrl+S. Alternatively, you can use Save As... to save the project with a different name. 
- Use "Open Project" or Ctrl+O to load a previously saved project. - Click "Import Annotations with Images" to load existing COCO JSON annotations along with their images. - Use "Export Annotations" to save annotations in various formats (COCO JSON, YOLO v8/v11, Labeled images, Semantic labels, Pascal VOC). - Note: YOLO export (and import) is now compatible with YOLOv11 structure. (Project directory includes data.yaml, train, and valid directories, with train and valid both having images and labels subdirectories.) - Project Details: - Access project details by selecting "Project Details" from the Project menu. - View project metadata such as creation date, last modified date, and image information. - Add or edit custom project notes. - Project details are automatically saved when you make changes to the notes. - Advanced Project Search: - Access the search functionality by selecting "Search Projects" from the Project menu. - Search through multiple projects using complex queries. - Use logical operators (AND, OR) and parentheses for advanced search criteria. - Search covers project name, class names, image names, and project notes. - Example queries: - "cells AND dog": Find projects containing both "cells" and "dog" - "cells OR bacteria": Find projects containing either "cells" or "bacteria" - "cells AND (dog OR monkey)": Find projects containing "cells" and either "dog" or "monkey" - "(project1 OR project2) AND (cells OR bacteria)": More complex nested queries - Double-click on search results to open the corresponding project. - Access additional tools under the Tools menu bar: - Annotation Statistics - COCO JSON Combiner - Dataset Splitter - Stack to Slices Converter - Image Patcher - Image Augmenter - Each tool opens a separate UI to guide you through the respective task. - Access the help documentation by clicking the "Help" button or pressing F1. - Explore the interface – you might stumble upon some hidden gems and secret features! 3. 
Keyboard shortcuts: - Ctrl + N: Create a new project - Ctrl + O: Open an existing project - Ctrl + S: Save the current project - Ctrl + W: Close the current project - Ctrl + Shift + S: Open Annotation Statistics - F1: Open the help window - Ctrl + Wheel: Zoom in/out - Hold Ctrl and drag: Pan the image - Esc: Cancel current annotation, exit edit mode, or exit SAM-assisted annotation - Enter: Finish current annotation, exit edit mode, or accept SAM-generated mask - Up/Down Arrow Keys: Navigate through slices in multi-dimensional images - - and =: Adjust pen size for paint brush and eraser tools ## Known Issues and Bug Fixes - The application may not work correctly on Linux systems. Extensive testing has not been done yet. - When loading a YOLO model trained on different classes compared to the loaded YAML file, the application now gives a message to the user about the mismatch instead of crashing. - Various other bugs have been addressed to improve overall stability and performance. ## Development For development purposes, you can clone the repository and install it in editable mode: 1. Clone the repository: ```bash git clone https://github.com/bnsreenu/digitalsreeni-image-annotator.git cd digitalsreeni-image-annotator ``` 2. Create a virtual environment (optional but recommended): ```bash python -m venv venv source venv/bin/activate # On Windows, use `venv\Scripts\activate` ``` 3. Install the package and its dependencies in editable mode: ```bash pip install -e . ``` ## Contributing Contributions are welcome! Please feel free to submit a Pull Request. 1. Fork the repository 2. Create your feature branch (`git checkout -b feature/AmazingFeature`) 3. Commit your changes (`git commit -m 'Add some AmazingFeature'`) 4. Push to the branch (`git push origin feature/AmazingFeature`) 5. Open a Pull Request ## License This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 
## Acknowledgments - Thanks to all my [YouTube](http://www.youtube.com/c/DigitalSreeni) subscribers who inspired me to work on this project - Inspired by the need for efficient image annotation in computer vision tasks ## Contact Dr. Sreenivas Bhattiprolu - [@DigitalSreeni](https://twitter.com/DigitalSreeni) Project Link: [https://github.com/bnsreenu/digitalsreeni-image-annotator](https://github.com/bnsreenu/digitalsreeni-image-annotator) ## Citing If you use this software in your research, please cite it as follows: Bhattiprolu, S. (2024). DigitalSreeni Image Annotator [Computer software]. https://github.com/bnsreenu/digitalsreeni-image-annotator ```bibtex @software{digitalsreeni_image_annotator, author = {Bhattiprolu, Sreenivas}, title = {DigitalSreeni Image Annotator}, year = {2024}, url = {https://github.com/bnsreenu/digitalsreeni-image-annotator} } ``` ================================================ FILE: Release Notes 0.8.12.md ================================================ # Release Notes ## Version 0.8.12 ### New Features and Enhancements - Same as version 0.8.9, except that the requirements file now pins specific version numbers for the libraries used. - The following bug fixes and optimizations correspond to version 0.8.9 ### Bug Fixes and Optimizations 1. 
**Project Corruption Prevention** - Fixed critical issue where projects could become corrupted if application was terminated during loading - Disabled auto-save functionality during project loading process - Enhanced project loading stability for large datasets - Protected project integrity when handling multiple classes and images ### Notes - All existing tools continue to support both Windows and macOS operating systems - Improved reliability of project file handling - Critical update recommended for users working with large projects ================================================ FILE: data/YOLO11n-model-yaml/coco8.yaml ================================================ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] path: ../datasets/coco8 # dataset root dir train: images/train # train images (relative to 'path') 4 images val: images/val # val images (relative to 'path') 4 images test: # test images (optional) # Classes names: 0: person 1: bicycle 2: car 3: motorcycle 4: airplane 5: bus 6: train 7: truck 8: boat 9: traffic light 10: fire hydrant 11: stop sign 12: parking meter 13: bench 14: bird 15: cat 16: dog 17: horse 18: sheep 19: cow 20: elephant 21: bear 22: zebra 23: giraffe 24: backpack 25: umbrella 26: handbag 27: tie 28: suitcase 29: frisbee 30: skis 31: snowboard 32: sports ball 33: kite 34: baseball bat 35: baseball glove 36: skateboard 37: surfboard 38: tennis racket 39: bottle 40: wine glass 41: cup 42: fork 43: knife 44: spoon 45: bowl 46: banana 47: apple 48: sandwich 49: orange 50: broccoli 51: carrot 52: hot dog 53: pizza 54: donut 55: cake 56: chair 57: couch 58: potted plant 59: bed 60: dining table 61: toilet 62: tv 63: laptop 64: mouse 65: remote 66: keyboard 67: cell phone 68: microwave 69: oven 70: toaster 71: sink 72: refrigerator 73: book 74: clock 75: vase 76: scissors 77: teddy bear 78: hair drier 79: toothbrush # Download script/URL (optional) download: 
https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8.zip ================================================ FILE: data/YOLO11n-model-yaml/download_YOLO_models.txt ================================================ https://docs.ultralytics.com/tasks/segment/#models Recommended model: https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n-seg.pt ================================================ FILE: data/download_SAM_models.txt ================================================ It is recommended to pre-download SAM models and place them in your working directory (the directory from which you start this application). This avoids downloading the models multiple times. Download models from: https://docs.ultralytics.com/models/sam-2/ Direct Download links: Tiny model: https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_t.pt Small: https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_s.pt Base: https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_b.pt Large: https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_l.pt Be cautious with the large model as it demands higher computing and memory resources from your system. ================================================ FILE: requirements.txt ================================================ PyQt5==5.15.11 Pillow==11.0.0 numpy==2.1.3 tifffile==2023.3.15 czifile==2019.7.2 opencv-python==4.10.0.84 pyyaml==6.0.2 scikit-image==0.24.0 ultralytics==8.3.27 plotly==5.24.1 shapely==2.0.6 pystackreg==0.2.8 pydicom==3.0.1 ================================================ FILE: setup.py ================================================ """ Setup file for the DigitalSreeni Image Annotator package. @DigitalSreeni Dr. 
Sreenivas Bhattiprolu """ from setuptools import setup, find_packages with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() setup( name="digitalsreeni-image-annotator", version="0.8.12", # Updated version number author="Dr. Sreenivas Bhattiprolu", author_email="digitalsreeni@gmail.com", description="A tool for annotating images using manual and automated tools, supporting multi-dimensional images and SAM2-assisted annotations", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/bnsreenu/digitalsreeni-image-annotator", packages=find_packages(where="src"), package_dir={"": "src"}, classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Science/Research", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python :: 3.10", ], python_requires=">=3.10", install_requires=[ "PyQt5==5.15.11", "numpy==2.1.3", "Pillow==11.0.0", "tifffile==2023.3.15", "czifile==2019.7.2", "opencv-python==4.10.0.84", "pyyaml==6.0.2", "scikit-image==0.24.0", "ultralytics==8.3.27", "plotly==5.24.1", "shapely==2.0.6", "pystackreg==0.2.8", "pydicom==3.0.1" ], entry_points={ "console_scripts": [ "digitalsreeni-image-annotator=digitalsreeni_image_annotator.main:main", "sreeni=digitalsreeni_image_annotator.main:main", ], }, ) ================================================ FILE: src/digitalsreeni_image_annotator/__init__.py ================================================ """ Image Annotator =============== A tool for annotating images with polygons and rectangles. This package provides a GUI application for image annotation, supporting polygon and rectangle annotations in a COCO-compatible format. @DigitalSreeni Dr. Sreenivas Bhattiprolu """ __version__ = "0.8.12" __author__ = "Dr. 
Sreenivas Bhattiprolu" from .annotator_window import ImageAnnotator from .image_label import ImageLabel from .utils import calculate_area, calculate_bbox from .sam_utils import SAMUtils __all__ = ['ImageAnnotator', 'ImageLabel', 'calculate_area', 'calculate_bbox', 'SAMUtils'] # Add 'SAMUtils' to this list ================================================ FILE: src/digitalsreeni_image_annotator/annotation_statistics.py ================================================ import plotly.graph_objects as go from plotly.subplots import make_subplots from PyQt5.QtWidgets import QDialog, QVBoxLayout, QTextBrowser, QPushButton, QHBoxLayout from PyQt5.QtCore import Qt import tempfile import os import webbrowser class AnnotationStatisticsDialog(QDialog): def __init__(self, parent=None): super().__init__(parent) self.setWindowTitle("Annotation Statistics") self.setGeometry(100, 100, 600, 400) self.setWindowFlags(self.windowFlags() | Qt.Window) self.initUI() def initUI(self): layout = QVBoxLayout() self.text_browser = QTextBrowser() layout.addWidget(self.text_browser) button_layout = QHBoxLayout() self.show_plot_button = QPushButton("Show Interactive Plot") self.show_plot_button.clicked.connect(self.show_interactive_plot) button_layout.addWidget(self.show_plot_button) layout.addLayout(button_layout) self.setLayout(layout) self.plot_file = None def show_centered(self, parent): parent_geo = parent.geometry() self.move(parent_geo.center() - self.rect().center()) self.show() def generate_statistics(self, annotations): try: # Class distribution class_distribution = {} objects_per_image = {} total_objects = 0 for image, image_annotations in annotations.items(): objects_in_image = 0 for class_name, class_annotations in image_annotations.items(): class_count = len(class_annotations) class_distribution[class_name] = class_distribution.get(class_name, 0) + class_count objects_in_image += class_count total_objects += class_count objects_per_image[image] = objects_in_image 
avg_objects_per_image = total_objects / len(annotations) if annotations else 0 # Create plots fig = make_subplots(rows=2, cols=1, subplot_titles=("Class Distribution", "Objects per Image")) # Class distribution plot fig.add_trace(go.Bar(x=list(class_distribution.keys()), y=list(class_distribution.values()), name="Classes"), row=1, col=1) # Objects per image plot fig.add_trace(go.Bar( x=list(objects_per_image.keys()), y=list(objects_per_image.values()), name="Images", hovertext=[f"{img}: {count}" for img, count in objects_per_image.items()], hoverinfo="text" ), row=2, col=1) # Update layout fig.update_layout(height=800, title_text="Annotation Statistics") # Hide x-axis labels for the second subplot (Objects per Image) fig.update_xaxes(showticklabels=False, title_text="Images", row=2, col=1) # Update y-axis title for the second subplot fig.update_yaxes(title_text="Number of Objects", row=2, col=1) # Save the plot to a temporary HTML file with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as tmp: fig.write_html(tmp.name) self.plot_file = tmp.name # Display statistics in the text browser stats_text = f"Total objects: {total_objects}\n" stats_text += f"Average objects per image: {avg_objects_per_image:.2f}\n\n" stats_text += "Class distribution:\n" for class_name, count in class_distribution.items(): stats_text += f" {class_name}: {count}\n" self.text_browser.setPlainText(stats_text) except Exception as e: self.text_browser.setPlainText(f"An error occurred while generating statistics: {str(e)}") self.show_plot_button.setEnabled(False) def show_interactive_plot(self): if self.plot_file and os.path.exists(self.plot_file): webbrowser.open('file://' + os.path.realpath(self.plot_file)) else: self.text_browser.append("Error: Plot file not found.") def closeEvent(self, event): if self.plot_file and os.path.exists(self.plot_file): os.unlink(self.plot_file) super().closeEvent(event) def show_annotation_statistics(parent, annotations): dialog = 
AnnotationStatisticsDialog(parent) dialog.generate_statistics(annotations) dialog.show_centered(parent) return dialog ================================================ FILE: src/digitalsreeni_image_annotator/annotation_utils.py ================================================ from PyQt5.QtWidgets import QListWidgetItem from PyQt5.QtGui import QColor from PyQt5.QtCore import Qt class AnnotationUtils: @staticmethod def update_annotation_list(self, image_name=None): self.annotation_list.clear() current_name = image_name or self.current_slice or self.image_file_name annotations = self.all_annotations.get(current_name, {}) for class_name, class_annotations in annotations.items(): color = self.image_label.class_colors.get(class_name, QColor(Qt.white)) for i, annotation in enumerate(class_annotations, start=1): item_text = f"{class_name} - {i}" item = QListWidgetItem(item_text) item.setData(Qt.UserRole, annotation) item.setForeground(color) self.annotation_list.addItem(item) @staticmethod def update_slice_list_colors(self): for i in range(self.slice_list.count()): item = self.slice_list.item(i) slice_name = item.text() if slice_name in self.all_annotations and any(self.all_annotations[slice_name].values()): item.setForeground(QColor(Qt.green)) else: item.setForeground(QColor(Qt.black) if not self.dark_mode else QColor(Qt.white)) @staticmethod def update_annotation_list_colors(self, class_name=None, color=None): for i in range(self.annotation_list.count()): item = self.annotation_list.item(i) annotation = item.data(Qt.UserRole) if class_name is None or annotation['category_name'] == class_name: item_color = color if class_name else self.image_label.class_colors.get(annotation['category_name'], QColor(Qt.white)) item.setForeground(item_color) @staticmethod def load_image_annotations(self): self.image_label.annotations.clear() current_name = self.current_slice or self.image_file_name if current_name in self.all_annotations: self.image_label.annotations = 
self.all_annotations[current_name].copy() self.image_label.update() @staticmethod def save_current_annotations(self): current_name = self.current_slice or self.image_file_name if current_name: if self.image_label.annotations: self.all_annotations[current_name] = self.image_label.annotations.copy() elif current_name in self.all_annotations: del self.all_annotations[current_name] AnnotationUtils.update_slice_list_colors(self) @staticmethod def add_annotation_to_list(self, annotation): class_name = annotation['category_name'] color = self.image_label.class_colors.get(class_name, QColor(Qt.white)) annotations = self.image_label.annotations.get(class_name, []) item_text = f"{class_name} - {len(annotations)}" item = QListWidgetItem(item_text) item.setData(Qt.UserRole, annotation) item.setForeground(color) self.annotation_list.addItem(item) ================================================ FILE: src/digitalsreeni_image_annotator/annotator_window.py ================================================ import os import json from PyQt5.QtWidgets import (QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QListWidget, QInputDialog, QLabel, QButtonGroup, QListWidgetItem, QScrollArea, QCheckBox, QSlider, QMenu, QMessageBox, QColorDialog, QDialog, QDoubleSpinBox, QGridLayout, QComboBox, QAbstractItemView, QProgressDialog, QApplication, QAction, QLineEdit, QTextEdit, QDialogButtonBox, QProgressBar) from PyQt5.QtGui import QPixmap, QColor, QIcon, QImage, QFont, QKeySequence, QPalette from PyQt5.QtCore import Qt, QThread, pyqtSignal import numpy as np from tifffile import TiffFile from czifile import CziFile import cv2 from datetime import datetime from .image_label import ImageLabel from .utils import calculate_area, calculate_bbox from .help_window import HelpWindow from .soft_dark_stylesheet import soft_dark_stylesheet from .default_stylesheet import default_stylesheet from .dataset_splitter import DatasetSplitterTool from .annotation_statistics import 
show_annotation_statistics from .coco_json_combiner import show_coco_json_combiner from .stack_to_slices import show_stack_to_slices from .image_patcher import show_image_patcher from .image_augmenter import show_image_augmenter from .slice_registration import SliceRegistrationTool from .sam_utils import SAMUtils from .snake_game import SnakeGame from .yolo_trainer import YOLOTrainer, TrainingInfoDialog, LoadPredictionModelDialog from .stack_interpolator import StackInterpolator from .dicom_converter import DicomConverter from shapely.geometry import Polygon, MultiPolygon, Point from shapely.ops import unary_union from shapely.validation import make_valid import shapely from .export_formats import ( export_coco_json, export_yolo_v4, export_yolo_v5plus, export_labeled_images, export_semantic_labels, export_pascal_voc_bbox, export_pascal_voc_both ) from .import_formats import import_coco_json, import_yolo_v4, import_yolo_v5plus from .import_formats import process_import_format import shutil import copy from ultralytics import SAM import warnings warnings.filterwarnings("ignore", category=UserWarning) class TrainingThread(QThread): progress_update = pyqtSignal(str) finished = pyqtSignal(object) def __init__(self, yolo_trainer, epochs, imgsz): super().__init__() self.yolo_trainer = yolo_trainer self.epochs = epochs self.imgsz = imgsz def run(self): try: results = self.yolo_trainer.train_model(epochs=self.epochs, imgsz=self.imgsz) self.finished.emit(results) except Exception as e: self.finished.emit(str(e)) class DimensionDialog(QDialog): def __init__(self, shape, file_name, parent=None, default_dimensions=None): super().__init__(parent) self.setWindowTitle("Assign Dimensions") layout = QVBoxLayout(self) # Add file name label file_name_label = QLabel(f"File: {file_name}") file_name_label.setWordWrap(True) layout.addWidget(file_name_label) # Add dimension assignment widgets dim_widget = QWidget() dim_layout = QGridLayout(dim_widget) self.combos = [] self.shape = shape 
dimensions = ['T', 'Z', 'C', 'S', 'H', 'W'] for i, dim in enumerate(shape): dim_layout.addWidget(QLabel(f"Dimension {i} (size {dim}):"), i, 0) combo = QComboBox() combo.addItems(dimensions) if default_dimensions and i < len(default_dimensions): combo.setCurrentText(default_dimensions[i]) dim_layout.addWidget(combo, i, 1) self.combos.append(combo) layout.addWidget(dim_widget) self.button = QPushButton("OK") self.button.clicked.connect(self.accept) layout.addWidget(self.button) self.setMinimumWidth(300) def get_dimensions(self): return [combo.currentText() for combo in self.combos] class ImageAnnotator(QMainWindow): def __init__(self): super().__init__() self.is_loading_project = False self.backup_project_path = None self.setWindowTitle("Image Annotator") self.setGeometry(100, 100, 1400, 800) self.central_widget = QWidget() self.setCentralWidget(self.central_widget) self.layout = QHBoxLayout(self.central_widget) self.create_menu_bar() # Initialize image_label early self.image_label = ImageLabel() self.image_label.set_main_window(self) # Initialize attributes self.current_image = None self.current_class = None self.image_file_name = "" self.all_annotations = {} self.all_images = [] self.image_paths = {} self.loaded_json = None self.class_mapping = {} self.editing_mode = False self.current_slice = None self.slices = [] self.current_stack = None self.image_dimensions = {} self.image_slices = {} self.image_shapes = {} # For paint brush and eraser self.paint_brush_size = 10 self.eraser_size = 10 # Initialize SAM utils self.current_sam_model = None self.sam_utils = SAMUtils() # Create sam_magic_wand_button self.sam_magic_wand_button = QPushButton("Magic Wand") self.sam_magic_wand_button.setCheckable(True) self.sam_magic_wand_button.setEnabled(False) # Initially disable the button # Initialize tool group self.tool_group = QButtonGroup(self) self.tool_group.setExclusive(False) # Font size control self.font_sizes = {"Small": 8, "Medium": 10, "Large": 12, "XL": 14, "XXL": 16} 
# Also, add the options in create_menu_bar method self.current_font_size = "Medium" # Dark mode control self.dark_mode = False # Default annotations sorting self.current_sort_method = "class" # Default sorting method # Setup UI components self.setup_ui() # Apply theme and font (this includes stylesheet and font size application) self.apply_theme_and_font() # Connect sam_magic_wand_button self.sam_magic_wand_button.clicked.connect(self.toggle_tool) self.class_list.itemChanged.connect(self.toggle_class_visibility) # YOLO Trainer self.yolo_trainer = None self.setup_yolo_menu() # Start in maximized mode self.showMaximized() def setup_ui(self): # Initialize the main layout self.central_widget = QWidget() self.setCentralWidget(self.central_widget) self.layout = QHBoxLayout(self.central_widget) # Initialize tool group self.tool_group = QButtonGroup(self) self.tool_group.setExclusive(False) # Setup UI components self.setup_sidebar() self.setup_image_area() self.setup_image_list() self.setup_slice_list() self.update_ui_for_current_tool() def update_window_title(self): base_title = "Image Annotator" if hasattr(self, 'current_project_file'): project_name = os.path.basename(self.current_project_file) project_name = os.path.splitext(project_name)[0] # Remove the file extension self.setWindowTitle(f"{base_title} - {project_name}") else: self.setWindowTitle(base_title) def new_project(self): self.remove_all_temp_annotations() # Remove temp annotations from the previous project project_file, _ = QFileDialog.getSaveFileName(self, "Create New Project", "", "Image Annotator Project (*.iap)") if project_file: # Ensure the file has the correct extension if not project_file.lower().endswith('.iap'): project_file += '.iap' self.current_project_file = project_file self.current_project_dir = os.path.dirname(project_file) # Create the images directory images_dir = os.path.join(self.current_project_dir, "images") os.makedirs(images_dir, exist_ok=True) # Clear existing data without showing 
messages self.clear_all(new_project=True, show_messages=False) # Prompt for initial project notes notes, ok = QInputDialog.getMultiLineText(self, "Project Notes", "Enter initial project notes:") if ok: self.project_notes = notes else: self.project_notes = "" self.project_creation_date = datetime.now().isoformat() # Save the empty project without showing a message self.save_project(show_message=False) # Keep only this message self.show_info("New Project", f"New project created at {self.current_project_file}") self.initialize_yolo_trainer() self.update_window_title() def show_project_search(self): from .project_search import show_project_search show_project_search(self) def open_project(self): print("open_project method called") # Debug print self.remove_all_temp_annotations() # Remove temp annotations from the previous project project_file, _ = QFileDialog.getOpenFileName(self, "Open Project", "", "Image Annotator Project (*.iap)") print(f"Selected project file: {project_file}") # Debug print if project_file: try: self.backup_project_before_open(project_file) self.open_specific_project(project_file) except Exception as e: self.restore_project_from_backup() QMessageBox.critical(self, "Error", f"An error occurred while opening the project: {str(e)}\n" f"The project file has been restored from backup.") else: print("No project file selected") # Debug print def backup_project_before_open(self, project_file): """Create a backup of the project file before opening it.""" import shutil import os timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") backup_dir = os.path.join(os.path.dirname(project_file), ".project_backups") os.makedirs(backup_dir, exist_ok=True) self.backup_project_path = os.path.join(backup_dir, f"{os.path.basename(project_file)}.{timestamp}.backup") shutil.copy2(project_file, self.backup_project_path) def restore_project_from_backup(self): """Restore the project file from its backup if available.""" if self.backup_project_path and 
os.path.exists(self.backup_project_path): try: shutil.copy2(self.backup_project_path, self.current_project_file) print(f"Project restored from backup: {self.backup_project_path}") except Exception as e: print(f"Failed to restore from backup: {str(e)}") def open_specific_project(self, project_file): print(f"Opening specific project: {project_file}") # Debug print if os.path.exists(project_file): try: self.is_loading_project = True # Set loading flag with open(project_file, 'r') as f: project_data = json.load(f) self.clear_all(show_messages=False) self.current_project_file = project_file self.current_project_dir = os.path.dirname(project_file) # Load project notes and metadata self.project_notes = project_data.get('notes', '') self.project_creation_date = project_data.get('creation_date', '') self.last_modified = project_data.get('last_modified', '') # Parse dates if self.project_creation_date: self.project_creation_date = datetime.fromisoformat(self.project_creation_date).strftime("%Y-%m-%d %H:%M:%S") if self.last_modified: self.last_modified = datetime.fromisoformat(self.last_modified).strftime("%Y-%m-%d %H:%M:%S") # Load all data without triggering auto-saves self.load_project_data(project_data) # Now save once after everything is loaded self.is_loading_project = False # Clear loading flag self.save_project(show_message=False) # Save once after loading self.initialize_yolo_trainer() self.update_window_title() print(f"Project opened successfully: {project_file}") QMessageBox.information(self, "Project Opened", f"Project opened successfully: {os.path.basename(project_file)}") except Exception as e: self.is_loading_project = False # Make sure to clear flag on error raise e else: print(f"Project file not found: {project_file}") QMessageBox.critical(self, "Error", f"Project file not found: {project_file}") def load_project_data(self, project_data): """Load project data without triggering auto-saves.""" # Load classes self.class_mapping.clear() 
self.image_label.class_colors.clear() for class_info in project_data.get('classes', []): self.add_class(class_info['name'], QColor(class_info['color'])) # Load images self.all_images = project_data.get('images', []) self.image_paths = project_data.get('image_paths', {}) # Load all annotations first self.all_annotations.clear() for image_info in project_data['images']: if image_info.get('is_multi_slice', False): for slice_info in image_info.get('slices', []): self.all_annotations[slice_info['name']] = slice_info['annotations'] else: self.all_annotations[image_info['file_name']] = image_info.get('annotations', {}) # Handle missing images missing_images = [] for image_info in project_data['images']: image_path = os.path.join(self.current_project_dir, "images", image_info['file_name']) if not os.path.exists(image_path): missing_images.append(image_info['file_name']) continue # Update image_paths self.image_paths[image_info['file_name']] = image_path if image_info.get('is_multi_slice', False): dimensions = image_info.get('dimensions', []) shape = image_info.get('shape', []) self.load_multi_slice_image(image_path, dimensions, shape) else: self.add_images_to_list([image_path]) # Update UI self.update_ui() # Handle missing images if any if missing_images: self.handle_missing_images(missing_images) # Select the first image if available if self.image_list.count() > 0: self.image_list.setCurrentRow(0) first_item = self.image_list.item(0) if first_item: self.switch_image(first_item) # Select the first class if available if self.class_list.count() > 0: self.class_list.setCurrentRow(0) self.on_class_selected() def handle_missing_images(self, missing_images): message = "The following images have annotations but were not found in the project directory:\n\n" message += "\n".join(missing_images[:10]) # Show first 10 missing images if len(missing_images) > 10: message += f"\n... and {len(missing_images) - 10} more." message += "\n\nWould you like to locate these images now?" 
reply = QMessageBox.question(self, "Missing Images", message, QMessageBox.Yes | QMessageBox.No, QMessageBox.Yes) if reply == QMessageBox.Yes: self.load_missing_images(missing_images) else: self.remove_missing_images(missing_images) def remove_missing_images(self, missing_images): for image_name in missing_images: # Remove from all_images self.all_images = [img for img in self.all_images if img['file_name'] != image_name] # Remove from image_paths self.image_paths.pop(image_name, None) # Remove from all_annotations self.all_annotations.pop(image_name, None) # If it's a multi-slice image, remove all related slices base_name = os.path.splitext(image_name)[0] if base_name in self.image_slices: for slice_name, _ in self.image_slices[base_name]: self.all_annotations.pop(slice_name, None) del self.image_slices[base_name] self.update_ui() QMessageBox.information(self, "Images Removed", f"{len(missing_images)} missing images and their annotations have been removed from the project.") def prompt_load_missing_images(self, missing_images): message = "The following images have annotations but were not found in the project directory:\n\n" message += "\n".join(missing_images[:10]) # Show first 10 missing images if len(missing_images) > 10: message += f"\n... and {len(missing_images) - 10} more." message += "\n\nWould you like to locate these images now?" 
reply = QMessageBox.question(self, "Load Missing Images", message, QMessageBox.Yes | QMessageBox.No, QMessageBox.Yes) if reply == QMessageBox.Yes: self.load_missing_images(missing_images) def load_missing_images(self, missing_images): files, _ = QFileDialog.getOpenFileNames(self, "Select Missing Images", "", "Image Files (*.png *.jpg *.bmp *.tif *.tiff *.czi)") if files: images_loaded = 0 for file_path in files: file_name = os.path.basename(file_path) if file_name in missing_images: dst_path = os.path.join(self.current_project_dir, "images", file_name) shutil.copy2(file_path, dst_path) self.image_paths[file_name] = dst_path # Add the image to all_images if it's not already there if not any(img['file_name'] == file_name for img in self.all_images): self.all_images.append({ "file_name": file_name, "height": 0, "width": 0, "id": len(self.all_images) + 1, "is_multi_slice": False }) images_loaded += 1 missing_images.remove(file_name) self.update_image_list() if images_loaded > 0: self.image_list.setCurrentRow(0) # Select the first image self.switch_image(self.image_list.item(0)) # Display the first image QMessageBox.information(self, "Images Loaded", f"Successfully copied and loaded {images_loaded} out of {len(files)} selected images.") # If there are still missing images, prompt again if missing_images: self.prompt_load_missing_images(missing_images) def update_image_list(self): self.image_list.clear() for image_info in self.all_images: self.image_list.addItem(image_info['file_name']) def select_class(self, index): if 0 <= index < self.class_list.count(): item = self.class_list.item(index) self.class_list.setCurrentItem(item) self.current_class = item.text() print(f"Selected class: {self.current_class}") else: print("Invalid class index") def close_project(self): if hasattr(self, 'current_project_file'): reply = QMessageBox.question(self, 'Close Project', "Do you want to save the current project before closing?", QMessageBox.Yes | QMessageBox.No | QMessageBox.Cancel) 
if reply == QMessageBox.Yes: self.remove_all_temp_annotations() # Remove temp annotations before saving self.save_project(show_message=False) # Save without showing a message elif reply == QMessageBox.Cancel: return # User cancelled the operation # Clear all data self.clear_all(new_project=True, show_messages=False) # Reset project-related attributes if hasattr(self, 'current_project_file'): del self.current_project_file if hasattr(self, 'current_project_dir'): del self.current_project_dir # Update the window title self.update_window_title() def delete_selected_class(self): selected_items = self.class_list.selectedItems() if not selected_items: QMessageBox.warning(self, "No Selection", "Please select a class to delete.") return class_name = selected_items[0].text() reply = QMessageBox.question(self, 'Delete Class', f"Are you sure you want to delete the class '{class_name}'?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No) if reply == QMessageBox.Yes: self.delete_class(class_name) # Sreeni note: Implement this method to handle class deletion def check_missing_images(self): missing_images = [img['file_name'] for img in self.all_images if img['file_name'] not in self.image_paths or not os.path.exists(self.image_paths[img['file_name']])] if missing_images: self.prompt_load_missing_images(missing_images) def convert_to_serializable(self, obj): if isinstance(obj, np.integer): return int(obj) elif isinstance(obj, np.floating): return float(obj) elif isinstance(obj, np.ndarray): return obj.tolist() elif isinstance(obj, list): return [self.convert_to_serializable(item) for item in obj] elif isinstance(obj, dict): return {key: self.convert_to_serializable(value) for key, value in obj.items()} else: return obj def save_project(self, show_message=True): if not hasattr(self, 'current_project_file') or not self.current_project_file: self.current_project_file, _ = QFileDialog.getSaveFileName(self, "Save Project", "", "Image Annotator Project (*.iap)") if not 
self.current_project_file: return # User cancelled the save dialog self.current_project_dir = os.path.dirname(self.current_project_file) # Check if images are in the correct directory structure images_dir = os.path.join(self.current_project_dir, "images") os.makedirs(images_dir, exist_ok=True) images_to_copy = [] for file_name, src_path in self.image_paths.items(): dst_path = os.path.join(images_dir, file_name) if os.path.abspath(src_path) != os.path.abspath(dst_path): if not os.path.exists(dst_path): images_to_copy.append((file_name, src_path, dst_path)) if images_to_copy: reply = QMessageBox.question(self, 'Image Directory Structure', f"The project structure requires all images to be in an 'images' subdirectory. " f"{len(images_to_copy)} images need to be copied to the correct location. " f"Do you want to copy these images?", QMessageBox.Yes | QMessageBox.No, QMessageBox.Yes) if reply == QMessageBox.Yes: for file_name, src_path, dst_path in images_to_copy: try: shutil.copy2(src_path, dst_path) self.image_paths[file_name] = dst_path except Exception as e: QMessageBox.warning(self, "Copy Failed", f"Failed to copy {file_name}: {str(e)}") return else: QMessageBox.warning(self, "Save Cancelled", "Project cannot be saved without the correct directory structure.") return # Prepare image data images_data = [] for image_info in self.all_images: file_name = image_info['file_name'] image_data = { 'file_name': file_name, 'width': image_info['width'], 'height': image_info['height'], 'is_multi_slice': image_info['is_multi_slice'] } if image_data['is_multi_slice']: base_name_without_ext = os.path.splitext(file_name)[0] image_data['slices'] = [] for slice_name, _ in self.image_slices.get(base_name_without_ext, []): slice_data = { 'name': slice_name, 'annotations': self.convert_to_serializable(self.all_annotations.get(slice_name, {})) } image_data['slices'].append(slice_data) image_data['dimensions'] = self.convert_to_serializable(self.image_dimensions.get(base_name_without_ext, 
[])) image_data['shape'] = self.convert_to_serializable(self.image_shapes.get(base_name_without_ext, [])) else: image_data['annotations'] = {} for class_name, annotations in self.all_annotations.get(file_name, {}).items(): image_data['annotations'][class_name] = [ann.copy() for ann in annotations] images_data.append(image_data) # Create project data project_data = { 'classes': [ {'name': name, 'color': color.name()} for name, color in self.image_label.class_colors.items() ], 'images': images_data, 'image_paths': {k: v for k, v in self.image_paths.items() if os.path.exists(v)}, 'notes': getattr(self, 'project_notes', ''), 'creation_date': getattr(self, 'project_creation_date', datetime.now().isoformat()), 'last_modified': datetime.now().isoformat() } # Save project data with open(self.current_project_file, 'w') as f: json.dump(self.convert_to_serializable(project_data), f, indent=2) if show_message: self.show_info("Project Saved", f"Project saved to {self.current_project_file}") # Update the window title self.update_window_title() # Update image_paths to reflect the correct locations for file_name in self.image_paths.keys(): self.image_paths[file_name] = os.path.join(images_dir, file_name) def save_project_as(self): new_project_file, _ = QFileDialog.getSaveFileName(self, "Save Project As", "", "Image Annotator Project (*.iap)") if new_project_file: # Ensure the file has the correct extension if not new_project_file.lower().endswith('.iap'): new_project_file += '.iap' # Store the original project file original_project_file = getattr(self, 'current_project_file', None) # Set the new project file as the current one self.current_project_file = new_project_file self.current_project_dir = os.path.dirname(new_project_file) # Save the project with the new name self.save_project(show_message=False) # Update the window title self.update_window_title() # Show a success message QMessageBox.information(self, "Project Saved As", f"Project saved as:\n{new_project_file}") # If this 
was originally a new unsaved project, update the original project file if original_project_file is None: self.current_project_file = new_project_file def auto_save(self): if self.is_loading_project: return # Skip auto-save during project loading if not hasattr(self, 'current_project_file'): reply = QMessageBox.question(self, 'No Project', "You need to save the project before auto-saving. Would you like to save now?", QMessageBox.Yes | QMessageBox.No, QMessageBox.Yes) if reply == QMessageBox.Yes: self.save_project() else: return if hasattr(self, 'current_project_file'): self.save_project(show_message=False) print("Project auto-saved.") def show_project_details(self): if not hasattr(self, 'current_project_file'): QMessageBox.warning(self, "No Project", "Please open or create a project first.") return from .project_details import ProjectDetailsDialog from .annotation_statistics import AnnotationStatisticsDialog # Generate annotation statistics stats_dialog = AnnotationStatisticsDialog(self) stats_dialog.generate_statistics(self.all_annotations) dialog = ProjectDetailsDialog(self, stats_dialog) if dialog.exec_() == QDialog.Accepted: if dialog.were_changes_made(): self.project_notes = dialog.get_notes() self.save_project(show_message=False) QMessageBox.information(self, "Project Details", "Project details have been updated.") else: print("No changes made to project details.") def load_multi_slice_image(self, image_path, dimensions=None, shape=None): file_name = os.path.basename(image_path) base_name = os.path.splitext(file_name)[0] print(f"Loading multi-slice image: {image_path}") print(f"Base name: {base_name}") if dimensions and shape: print(f"Using stored dimensions: {dimensions}") print(f"Using stored shape: {shape}") self.image_dimensions[base_name] = dimensions self.image_shapes[base_name] = shape if image_path.lower().endswith(('.tif', '.tiff')): self.load_tiff(image_path, dimensions, shape) elif image_path.lower().endswith('.czi'): self.load_czi(image_path, 
dimensions, shape) else: print("No stored dimensions or shape, loading as new image") if image_path.lower().endswith(('.tif', '.tiff')): self.load_tiff(image_path) elif image_path.lower().endswith('.czi'): self.load_czi(image_path) print(f"Loaded multi-slice image: {file_name}") print(f"Dimensions: {self.image_dimensions.get(base_name, 'Not found')}") print(f"Shape: {self.image_shapes.get(base_name, 'Not found')}") print(f"Number of slices: {len(self.slices)}") if self.slices: self.current_image = self.slices[0][1] self.current_slice = self.slices[0][0] self.update_slice_list() self.slice_list.setCurrentRow(0) self.activate_slice(self.current_slice) print(f"Activated first slice: {self.current_slice}") else: print("No slices were loaded") self.current_image = None self.current_slice = None self.update_slice_list() self.image_label.update() # print(f"Loaded slices: {[slice_name for slice_name, _ in self.slices]}") def activate_sam_magic_wand(self): # Uncheck all other tools for button in self.tool_group.buttons(): if button != self.sam_magic_wand_button: button.setChecked(False) # Set the current tool self.image_label.current_tool = "sam_magic_wand" self.image_label.sam_magic_wand_active = True self.image_label.setCursor(Qt.CrossCursor) # Update UI based on the current tool self.update_ui_for_current_tool() # If a class is not selected, select the first one (if available) if self.current_class is None and self.class_list.count() > 0: self.class_list.setCurrentRow(0) self.current_class = self.class_list.currentItem().text() elif self.class_list.count() == 0: QMessageBox.warning(self, "No Class Selected", "Please add a class before using annotation tools.") self.sam_magic_wand_button.setChecked(False) self.deactivate_sam_magic_wand() def deactivate_sam_magic_wand(self): self.image_label.current_tool = None self.image_label.sam_magic_wand_active = False self.sam_magic_wand_button.setChecked(False) self.sam_magic_wand_button.setEnabled(False) # Disable the button 
self.image_label.setCursor(Qt.ArrowCursor) # Clear any SAM-related temporary data self.image_label.sam_bbox = None self.image_label.drawing_sam_bbox = False self.image_label.temp_sam_prediction = None # Update UI based on the current tool self.update_ui_for_current_tool() def toggle_sam_assisted(self): if not self.current_sam_model: QMessageBox.warning(self, "No SAM Model Selected", "Please pick a SAM model before using the SAM-Assisted tool.") self.sam_magic_wand_button.setChecked(False) return if self.sam_magic_wand_button.isChecked(): self.activate_sam_magic_wand() else: self.deactivate_sam_magic_wand() self.image_label.clear_temp_sam_prediction() # Clear temporary prediction def toggle_sam_magic_wand(self): if self.sam_magic_wand_button.isChecked(): if self.current_class is None: QMessageBox.warning(self, "No Class Selected", "Please select a class before using SAM2 Magic Wand.") self.sam_magic_wand_button.setChecked(False) return self.image_label.setCursor(Qt.CrossCursor) self.image_label.sam_magic_wand_active = True else: self.image_label.setCursor(Qt.ArrowCursor) self.image_label.sam_magic_wand_active = False self.image_label.sam_bbox = None self.image_label.clear_temp_sam_prediction() # Clear temporary prediction def apply_sam_prediction(self): if self.image_label.sam_bbox is None: print("SAM bbox is None") return x1, y1, x2, y2 = self.image_label.sam_bbox bbox = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] print(f"Applying SAM prediction with bbox: {bbox}") prediction = self.sam_utils.apply_sam_prediction(self.current_image, bbox) if prediction: temp_annotation = { "segmentation": prediction["segmentation"], "category_id": self.class_mapping[self.current_class], "category_name": self.current_class, "score": prediction["score"] } self.image_label.temp_sam_prediction = temp_annotation self.image_label.update() else: print("Failed to generate prediction") # Reset SAM bounding box self.image_label.sam_bbox = None self.image_label.update() def 
accept_sam_prediction(self): if self.image_label.temp_sam_prediction: new_annotation = self.image_label.temp_sam_prediction self.image_label.annotations.setdefault(new_annotation["category_name"], []).append(new_annotation) self.add_annotation_to_list(new_annotation) self.save_current_annotations() self.update_slice_list_colors() self.image_label.temp_sam_prediction = None self.image_label.update() print("SAM prediction accepted and added to annotations.") def setup_slice_list(self): self.slice_list = QListWidget() self.slice_list.itemClicked.connect(self.switch_slice) self.image_list_layout.addWidget(QLabel("Slices:")) self.image_list_layout.addWidget(self.slice_list) def qimage_to_numpy(self, qimage): width = qimage.width() height = qimage.height() ptr = qimage.bits() ptr.setsize(height * width * 4) arr = np.frombuffer(ptr, np.uint8).reshape((height, width, 4)) return arr[:, :, :3] # Slice off the alpha channel def open_images(self): file_names, _ = QFileDialog.getOpenFileNames(self, "Open Images", "", "Image Files (*.png *.jpg *.bmp *.tif *.tiff *.czi)") if file_names: self.image_list.clear() self.image_paths.clear() self.all_images.clear() self.slice_list.clear() self.slices.clear() self.current_stack = None self.current_slice = None self.add_images_to_list(file_names) def convert_to_8bit_rgb(self, image_array): if image_array.ndim == 2: # Grayscale image image_8bit = self.normalize_array(image_array) return np.stack((image_8bit,) * 3, axis=-1) elif image_array.ndim == 3: if image_array.shape[2] == 3: # Already RGB, just normalize return self.normalize_array(image_array) elif image_array.shape[2] > 3: # Multi-channel image, use first three channels rgb_array = image_array[:, :, :3] return self.normalize_array(rgb_array) raise ValueError(f"Unsupported image shape: {image_array.shape}") def add_images_to_list(self, file_names): first_added_item = None for file_name in file_names: base_name = os.path.basename(file_name) if base_name not in self.image_paths: 
image_info = { "file_name": base_name, "height": 0, "width": 0, "id": len(self.all_images) + 1, "is_multi_slice": False } # Detect multi-slice images and set dimensions if file_name.lower().endswith(('.tif', '.tiff', '.czi')): self.load_multi_slice_image(file_name) base_name_without_ext = os.path.splitext(base_name)[0] if base_name_without_ext in self.image_slices and self.image_slices[base_name_without_ext]: first_slice_name, first_slice = self.image_slices[base_name_without_ext][0] image_info["height"] = first_slice.height() image_info["width"] = first_slice.width() image_info["is_multi_slice"] = True image_info["dimensions"] = self.image_dimensions.get(base_name_without_ext, []) image_info["shape"] = self.image_shapes.get(base_name_without_ext, []) else: # For regular images image = QImage(file_name) image_info["height"] = image.height() image_info["width"] = image.width() self.all_images.append(image_info) item = QListWidgetItem(base_name) self.image_list.addItem(item) if first_added_item is None: first_added_item = item # Update image_paths with the original file path self.image_paths[base_name] = file_name if first_added_item: self.image_list.setCurrentItem(first_added_item) self.switch_image(first_added_item) if not self.is_loading_project: self.auto_save() def update_all_images(self, new_image_info): for info in new_image_info: if not any(img['file_name'] == info['file_name'] for img in self.all_images): self.all_images.append(info) def closeEvent(self, event): if not self.image_label.check_unsaved_changes(): event.ignore() return event.accept() if self.image_label.temp_paint_mask is not None or self.image_label.temp_eraser_mask is not None: reply = QMessageBox.question(self, 'Unsaved Changes', "You have unsaved changes. 
Do you want to save them before closing?", QMessageBox.Yes | QMessageBox.No | QMessageBox.Cancel) if reply == QMessageBox.Yes: if self.image_label.temp_paint_mask is not None: self.image_label.commit_paint_annotation() if self.image_label.temp_eraser_mask is not None: self.image_label.commit_eraser_changes() elif reply == QMessageBox.Cancel: event.ignore() return # Perform any other cleanup or saving operations here event.accept() def switch_slice(self, item): if item is None: return if not self.image_label.check_unsaved_changes(): return # Check for unsaved changes if self.image_label.temp_paint_mask is not None or self.image_label.temp_eraser_mask is not None: reply = QMessageBox.question(self, 'Unsaved Changes', "You have unsaved changes. Do you want to save them?", QMessageBox.Yes | QMessageBox.No | QMessageBox.Cancel) if reply == QMessageBox.Yes: if self.image_label.temp_paint_mask is not None: self.image_label.commit_paint_annotation() if self.image_label.temp_eraser_mask is not None: self.image_label.commit_eraser_changes() elif reply == QMessageBox.Cancel: return else: self.image_label.discard_paint_annotation() self.image_label.discard_eraser_changes() self.save_current_annotations() self.image_label.clear_temp_sam_prediction() slice_name = item.text() for name, qimage in self.slices: if name == slice_name: self.current_image = qimage self.current_slice = name self.display_image() self.load_image_annotations() self.update_annotation_list() self.clear_highlighted_annotation() self.image_label.highlighted_annotations.clear() # Add this line self.image_label.reset_annotation_state() self.image_label.clear_current_annotation() self.update_image_info() break # Ensure the UI is updated self.image_label.update() self.update_slice_list_colors() # Reset zoom level to default (1.0) self.set_zoom(1.0) def switch_image(self, item): if item is None: return if not self.image_label.check_unsaved_changes(): return # Store the current item before checking temp annotations 
current_item = self.image_list.currentItem() if not self.check_temp_annotations(): # If the user chooses not to discard temp annotations, revert the selection self.image_list.setCurrentItem(current_item) return self.save_current_annotations() self.image_label.clear_temp_sam_prediction() self.image_label.exit_editing_mode() file_name = item.text() print(f"\nSwitching to image: {file_name}") image_info = next((img for img in self.all_images if img["file_name"] == file_name), None) if image_info: self.image_file_name = file_name image_path = self.image_paths.get(file_name) if not image_path: image_path = os.path.join(self.current_project_dir, "images", file_name) if image_path and os.path.exists(image_path): if image_info.get('is_multi_slice', False): base_name = os.path.splitext(file_name)[0] if base_name in self.image_slices: self.slices = self.image_slices[base_name] if self.slices: self.current_image = self.slices[0][1] self.current_slice = self.slices[0][0] self.update_slice_list() self.activate_slice(self.current_slice) else: self.load_multi_slice_image(image_path, image_info.get('dimensions'), image_info.get('shape')) else: self.load_regular_image(image_path) self.display_image() self.clear_slice_list() self.load_image_annotations() self.update_annotation_list() self.clear_highlighted_annotation() self.image_label.highlighted_annotations.clear() self.image_label.update() self.image_label.reset_annotation_state() self.image_label.clear_current_annotation() self.update_image_info() self.adjust_zoom_to_fit() else: self.current_image = None self.image_label.clear() self.load_image_annotations() self.update_annotation_list() self.update_image_info() self.image_list.setCurrentItem(item) self.image_label.update() self.update_slice_list_colors() else: self.current_image = None self.current_slice = None self.image_label.clear() self.update_image_info() self.clear_slice_list() def adjust_zoom_to_fit(self): if not self.current_image: return # Get the dimensions of the 
image and the display area image_width = self.current_image.width() image_height = self.current_image.height() display_width = self.scroll_area.viewport().width() display_height = self.scroll_area.viewport().height() # Calculate and apply the zoom factor to fit the longest side zoom_factor = min(display_width / image_width, display_height / image_height) self.set_zoom(zoom_factor) def activate_current_slice(self): if self.current_slice: # Ensure the current slice is selected in the slice list items = self.slice_list.findItems(self.current_slice, Qt.MatchExactly) if items: self.slice_list.setCurrentItem(items[0]) # Load annotations for the current slice self.load_image_annotations() # Update the image label self.image_label.update() # Update the annotation list self.update_annotation_list() def load_image(self, image_path): extension = os.path.splitext(image_path)[1].lower() if extension in ['.tif', '.tiff']: self.load_tiff(image_path) elif extension == '.czi': self.load_czi(image_path) else: self.load_regular_image(image_path) def load_tiff(self, image_path, dimensions=None, shape=None, force_dimension_dialog=False): print(f"Loading TIFF file: {image_path}") with TiffFile(image_path) as tif: print(f"TIFF tags: {tif.pages[0].tags}") # Try to access metadata if available try: metadata = tif.pages[0].tags['ImageDescription'].value print(f"TIFF metadata: {metadata}") except KeyError: print("No ImageDescription metadata found") # Check if it's a multi-page TIFF if len(tif.pages) > 1: print(f"Multi-page TIFF detected. 
Number of pages: {len(tif.pages)}") # Read all pages into a 3D array image_array = tif.asarray() else: print("Single-page TIFF detected.") image_array = tif.pages[0].asarray() print(f"Image array shape: {image_array.shape}") print(f"Image array dtype: {image_array.dtype}") print(f"Image min: {image_array.min()}, max: {image_array.max()}") if dimensions and shape and not force_dimension_dialog: # Use stored dimensions and shape print(f"Using stored dimensions: {dimensions}") print(f"Using stored shape: {shape}") image_array = image_array.reshape(shape) else: # Process as before for new images or when forcing dimension dialog print("Processing as new image or forcing dimension dialog.") dimensions = None self.process_multidimensional_image(image_array, image_path, dimensions, force_dimension_dialog) def load_czi(self, image_path, dimensions=None, shape=None, force_dimension_dialog=False): print(f"Loading CZI file: {image_path}") with CziFile(image_path) as czi: image_array = czi.asarray() print(f"CZI array shape: {image_array.shape}") print(f"CZI array dtype: {image_array.dtype}") print(f"CZI array min: {image_array.min()}, max: {image_array.max()}") if dimensions and shape and not force_dimension_dialog: # Use stored dimensions and shape print(f"Using stored dimensions: {dimensions}") print(f"Using stored shape: {shape}") image_array = image_array.reshape(shape) else: # Process as before for new images or when forcing dimension dialog print("Processing as new image or forcing dimension dialog.") dimensions = None self.process_multidimensional_image(image_array, image_path, dimensions, force_dimension_dialog) def load_regular_image(self, image_path): self.current_image = QImage(image_path) self.slices = [] self.slice_list.clear() self.current_slice = None def process_multidimensional_image(self, image_array, image_path, dimensions=None, force_dimension_dialog=False): file_name = os.path.basename(image_path) base_name = os.path.splitext(file_name)[0] 
print(f"Processing file: {file_name}") print(f"Image array shape: {image_array.shape}") print(f"Image array dtype: {image_array.dtype}") if dimensions is None or force_dimension_dialog: if image_array.ndim > 2: default_dimensions = ['Z', 'H', 'W'] if image_array.ndim == 3 else ['T', 'Z', 'H', 'W'] default_dimensions = default_dimensions[-image_array.ndim:] # Show a progress dialog progress = QProgressDialog("Assigning dimensions...", "Cancel", 0, 100, self) progress.setWindowModality(Qt.WindowModal) progress.setMinimumDuration(0) progress.setValue(10) QApplication.processEvents() while True: dialog = DimensionDialog(image_array.shape, file_name, self, default_dimensions) dialog.setWindowFlags(dialog.windowFlags() & ~Qt.WindowContextHelpButtonHint) progress.setValue(50) QApplication.processEvents() if dialog.exec_(): dimensions = dialog.get_dimensions() print(f"Assigned dimensions: {dimensions}") if 'H' in dimensions and 'W' in dimensions: self.image_dimensions[base_name] = dimensions break else: QMessageBox.warning(self, "Invalid Dimensions", "You must assign both H and W dimensions.") else: progress.close() return progress.setValue(100) progress.close() else: dimensions = ['H', 'W'] self.image_dimensions[base_name] = dimensions self.image_shapes[base_name] = image_array.shape print(f"Final assigned dimensions: {self.image_dimensions[base_name]}") print(f"Image shape: {self.image_shapes[base_name]}") if self.image_dimensions[base_name]: self.create_slices(image_array, self.image_dimensions[base_name], image_path) else: rgb_image = self.convert_to_8bit_rgb(image_array) self.current_image = self.array_to_qimage(rgb_image) self.slices = [] self.slice_list.clear() if self.slices: self.current_image = self.slices[0][1] self.current_slice = self.slices[0][0] self.slice_list.setCurrentRow(0) self.load_image_annotations() self.image_label.update() self.update_image_info() # Update UI self.update_slice_list() self.update_annotation_list() self.image_label.update() def 
create_slices(self, image_array, dimensions, image_path): base_name = os.path.splitext(os.path.basename(image_path))[0] slices = [] self.slice_list.clear() print(f"Creating slices for {base_name}") print(f"Dimensions: {dimensions}") print(f"Image array shape: {image_array.shape}") # Create and show progress dialog progress = QProgressDialog("Loading slices...", "Cancel", 0, 100, self) progress.setWindowModality(Qt.WindowModal) progress.setMinimumDuration(0) # Show immediately # Handle 2D images if image_array.ndim == 2: progress.setValue(50) # Update progress QApplication.processEvents() # Allow GUI to update normalized_array = self.normalize_array(image_array) qimage = self.array_to_qimage(normalized_array) slice_name = f"{base_name}" slices.append((slice_name, qimage)) self.add_slice_to_list(slice_name) else: # For 3D or higher dimensional arrays slice_indices = [i for i, dim in enumerate(dimensions) if dim not in ['H', 'W']] total_slices = np.prod([image_array.shape[i] for i in slice_indices]) for idx, _ in enumerate(np.ndindex(tuple(image_array.shape[i] for i in slice_indices))): if progress.wasCanceled(): break full_idx = [slice(None)] * len(dimensions) for i, val in zip(slice_indices, _): full_idx[i] = val slice_array = image_array[tuple(full_idx)] rgb_slice = self.convert_to_8bit_rgb(slice_array) qimage = self.array_to_qimage(rgb_slice) slice_name = f"{base_name}_{'_'.join([f'{dimensions[i]}{val+1}' for i, val in zip(slice_indices, _)])}" slices.append((slice_name, qimage)) self.add_slice_to_list(slice_name) # Update progress progress_value = int((idx + 1) / total_slices * 100) progress.setValue(progress_value) QApplication.processEvents() # Allow GUI to update progress.setValue(100) # Ensure progress reaches 100% self.image_slices[base_name] = slices self.slices = slices if slices: self.current_image = slices[0][1] self.current_slice = slices[0][0] self.slice_list.setCurrentRow(0) self.activate_slice(self.current_slice) slice_info = f"Total slices: 
{len(slices)}" for dim, size in zip(dimensions, image_array.shape): if dim not in ['H', 'W']: slice_info += f", {dim}: {size}" self.update_image_info(additional_info=slice_info) else: print("No slices were created") print(f"Created {len(slices)} slices for {base_name}") return slices def add_slice_to_list(self, slice_name): item = QListWidgetItem(slice_name) if self.dark_mode: # Dark mode item.setBackground(QColor(40, 40, 40)) # Very dark gray background for all items if slice_name in self.all_annotations: item.setForeground(QColor(60, 60, 60)) # Dark gray text item.setBackground(QColor(173, 216, 230)) # Light blue background else: item.setForeground(QColor(200, 200, 200)) # Light gray text else: # Light mode item.setBackground(QColor(240, 240, 240)) # Very light gray background for all items if slice_name in self.all_annotations: item.setForeground(QColor(255, 255, 255)) # White text item.setBackground(QColor(70, 130, 180)) # Medium-dark blue background else: item.setForeground(QColor(0, 0, 0)) # Black text self.slice_list.addItem(item) def normalize_array(self, array): # print(f"Normalizing array. 
Shape: {array.shape}, dtype: {array.dtype}") # print(f"Array min: {array.min()}, max: {array.max()}, mean: {array.mean()}") array_float = array.astype(np.float32) if array.dtype == np.uint16: array_normalized = (array_float - array.min()) / (array.max() - array.min()) elif array.dtype == np.uint8: # For 8-bit images, use a simple contrast stretching p_low, p_high = np.percentile(array_float, (0, 100)) #Change these to 1, 99 or something to stretch the contrast for visualizing 8 bit images array_normalized = np.clip(array_float, p_low, p_high) array_normalized = (array_normalized - p_low) / (p_high - p_low) else: array_normalized = (array_float - array.min()) / (array.max() - array.min()) # Apply gamma correction gamma = 1.0 # Adjust this value to fine-tune brightness (> 1 for darker, < 1 for brighter) array_normalized = np.power(array_normalized, gamma) return (array_normalized * 255).astype(np.uint8) def adjust_contrast(self, image, low_percentile=1, high_percentile=99): if image.dtype != np.uint8: p_low, p_high = np.percentile(image, (low_percentile, high_percentile)) image_adjusted = np.clip(image, p_low, p_high) image_adjusted = (image_adjusted - p_low) / (p_high - p_low) return (image_adjusted * 255).astype(np.uint8) return image def activate_slice(self, slice_name): self.current_slice = slice_name self.image_file_name = slice_name self.load_image_annotations() self.update_annotation_list() for name, qimage in self.slices: if name == slice_name: self.current_image = qimage self.display_image() break self.image_label.update() items = self.slice_list.findItems(slice_name, Qt.MatchExactly) if items: self.slice_list.setCurrentItem(items[0]) def array_to_qimage(self, array): if array.ndim == 2: height, width = array.shape return QImage(array.data, width, height, width, QImage.Format_Grayscale8) elif array.ndim == 3 and array.shape[2] == 3: height, width, _ = array.shape bytes_per_line = 3 * width return QImage(array.data, width, height, bytes_per_line, 
QImage.Format_RGB888) else: raise ValueError(f"Unsupported array shape {array.shape} for conversion to QImage") def update_slice_list(self): self.slice_list.clear() for slice_name, _ in self.slices: item = QListWidgetItem(slice_name) if slice_name in self.all_annotations: item.setForeground(QColor(Qt.green)) else: item.setForeground(QColor(Qt.black) if not self.dark_mode else QColor(Qt.white)) self.slice_list.addItem(item) # Select the current slice if self.current_slice: items = self.slice_list.findItems(self.current_slice, Qt.MatchExactly) if items: self.slice_list.setCurrentItem(items[0]) def clear_slice_list(self): self.slice_list.clear() self.slices = [] self.current_slice = None def reset_tool_buttons(self): for button in self.tool_group.buttons(): button.setChecked(False) def keyPressEvent(self, event): # Check if the current focus is on a text editing widget focused_widget = QApplication.focusWidget() if isinstance(focused_widget, (QLineEdit, QTextEdit)): super().keyPressEvent(event) return if event.key() == Qt.Key_F2: self.launch_snake_game() elif event.key() == Qt.Key_Delete: # Handle deletions if self.class_list.hasFocus() and self.class_list.currentItem(): self.delete_class(self.class_list.currentItem()) elif self.annotation_list.hasFocus() and self.annotation_list.selectedItems(): self.delete_selected_annotations() elif self.image_list.hasFocus() and self.image_list.currentItem(): self.delete_selected_image() elif event.key() == Qt.Key_Up or event.key() == Qt.Key_Down: # Handle slice navigation if self.slice_list.hasFocus(): current_row = self.slice_list.currentRow() if event.key() == Qt.Key_Up and current_row > 0: self.slice_list.setCurrentRow(current_row - 1) elif event.key() == Qt.Key_Down and current_row < self.slice_list.count() - 1: self.slice_list.setCurrentRow(current_row + 1) self.switch_slice(self.slice_list.currentItem()) else: # Pass the event to the parent for default handling super().keyPressEvent(event) elif event.key() == Qt.Key_Return 
or event.key() == Qt.Key_Enter: # Handle accepting visible temporary classes if self.has_visible_temp_classes(): self.accept_visible_temp_classes() else: super().keyPressEvent(event) elif event.key() == Qt.Key_Escape: # Handle rejecting visible temporary classes if self.has_visible_temp_classes(): self.reject_visible_temp_classes() else: super().keyPressEvent(event) else: # Pass any other key events to the parent for default handling super().keyPressEvent(event) def has_visible_temp_classes(self): for i in range(self.class_list.count()): item = self.class_list.item(i) if item.text().startswith("Temp-") and item.checkState() == Qt.Checked: return True return False def launch_snake_game(self): #print("Launching Snake game") if not hasattr(self, 'snake_game') or not self.snake_game.isVisible(): self.snake_game = SnakeGame() self.snake_game.show() self.snake_game.setFocus() def import_annotations(self): if not self.image_label.check_unsaved_changes(): return print("Starting import_annotations") import_format = self.import_format_selector.currentText() print(f"Import format: {import_format}") if import_format == "COCO JSON": file_name, _ = QFileDialog.getOpenFileName(self, "Import COCO JSON Annotations", "", "JSON Files (*.json)") if not file_name: print("No file selected, returning") return print(f"Selected file: {file_name}") json_dir = os.path.dirname(file_name) images_dir = os.path.join(json_dir, 'images') imported_annotations, image_info = import_coco_json(file_name, self.class_mapping) elif import_format in ["YOLO (v4 and earlier)", "YOLO (v5+)"]: yaml_file, _ = QFileDialog.getOpenFileName(self, "Select YOLO Dataset YAML", "", "YAML Files (*.yaml *.yml)") if not yaml_file: print("No YAML file selected, returning") return print(f"Selected YAML file: {yaml_file}") try: imported_annotations, image_info = process_import_format(import_format, yaml_file, self.class_mapping) yaml_dir = os.path.dirname(yaml_file) if import_format == "YOLO (v4 and earlier)": images_dir = 
os.path.join(yaml_dir, 'train', 'images') else: # YOLO (v5+) images_dir = os.path.join(yaml_dir, 'images', 'train') # Preferring train over val except ValueError as e: QMessageBox.warning(self, "Import Error", str(e)) return else: QMessageBox.warning(self, "Unsupported Format", f"The selected format '{import_format}' is not implemented for import.") return print(f"JSON/YOLO directory: {json_dir if import_format == 'COCO JSON' else os.path.dirname(yaml_file)}") print(f"Images directory: {images_dir}") print(f"Imported annotations count: {len(imported_annotations)}") print(f"Image info count: {len(image_info)}") images_loaded = 0 images_not_found = [] for info in image_info.values(): print(f"Processing image: {info['file_name']}") image_path = os.path.join(images_dir, info['file_name']) if os.path.exists(image_path): print(f"Image found at: {image_path}") self.image_paths[info['file_name']] = image_path self.all_images.append({ "file_name": info['file_name'], "height": info['height'], "width": info['width'], "id": info['id'], "is_multi_slice": False }) images_loaded += 1 else: print(f"Image not found at: {image_path}") images_not_found.append(info['file_name']) print(f"Images loaded: {images_loaded}") print(f"Images not found: {len(images_not_found)}") if images_not_found: message = f"The following {len(images_not_found)} images were not found in the 'images' directory:\n\n" message += "\n".join(images_not_found[:10]) if len(images_not_found) > 10: message += f"\n... and {len(images_not_found) - 10} more." message += "\n\nDo you want to proceed and ignore annotations for these missing images?" reply = QMessageBox.question(self, "Missing Images", message, QMessageBox.Yes | QMessageBox.No, QMessageBox.No) if reply == QMessageBox.No: print("Import cancelled due to missing images") QMessageBox.information(self, "Import Cancelled", "Import cancelled. 
Please ensure all images are in the 'images' directory and try again.") return # Update annotations (only for found images) for image_name, annotations in imported_annotations.items(): if image_name not in self.image_paths: continue self.all_annotations[image_name] = {} for category_name, category_annotations in annotations.items(): self.all_annotations[image_name][category_name] = [] for i, ann in enumerate(category_annotations, start=1): new_ann = { "segmentation": ann.get("segmentation"), "bbox": ann.get("bbox"), "category_id": ann["category_id"], "category_name": category_name, "number": i, "type": ann.get("type", "polygon") } self.all_annotations[image_name][category_name].append(new_ann) # Update class mapping and colors for annotations in self.all_annotations.values(): for category_name in annotations.keys(): if category_name not in self.class_mapping: new_id = len(self.class_mapping) + 1 self.class_mapping[category_name] = new_id self.image_label.class_colors[category_name] = QColor(Qt.GlobalColor(new_id % 16 + 7)) print("Updating UI") # Update UI self.update_class_list() self.update_image_list() self.update_annotation_list() # Highlight and display the first image if self.image_list.count() > 0: self.image_list.setCurrentRow(0) self.switch_image(self.image_list.item(0)) # Select the first class if available if self.class_list.count() > 0: self.class_list.setCurrentRow(0) self.on_class_selected() self.image_label.update() message = f"Annotations have been imported successfully from {file_name if import_format == 'COCO JSON' else yaml_file}.\n" message += f"{images_loaded} images were loaded from the 'images' directory.\n" if images_not_found: message += f"Annotations for {len(images_not_found)} missing images were ignored." 
print("Import complete, showing message") QMessageBox.information(self, "Import Complete", message) self.auto_save() # Auto-save after importing annotations def export_annotations(self): if not self.image_label.check_unsaved_changes(): return export_format = self.export_format_selector.currentText() supported_formats = [ "COCO JSON", "YOLO (v4 and earlier)", "YOLO (v5+)", "Labeled Images", "Semantic Labels", "Pascal VOC (BBox)", "Pascal VOC (BBox + Segmentation)" ] if export_format not in supported_formats: QMessageBox.warning(self, "Unsupported Format", f"The selected format '{export_format}' is not implemented.") return if export_format == "COCO JSON": file_name, _ = QFileDialog.getSaveFileName(self, "Export COCO JSON Annotations", "", "JSON Files (*.json)") else: file_name = QFileDialog.getExistingDirectory(self, f"Select Output Directory for {export_format} Export") if not file_name: return self.save_current_annotations() if export_format == "COCO JSON": output_dir = os.path.dirname(file_name) json_filename = os.path.basename(file_name) json_file, images_dir = export_coco_json(self.all_annotations, self.class_mapping, self.image_paths, self.slices, self.image_slices, output_dir, json_filename) message = "Annotations have been exported successfully in COCO JSON format.\n" message += f"JSON file: {json_file}\nImages directory: {images_dir}" elif export_format == "YOLO (v4 and earlier)": labels_dir, yaml_path = export_yolo_v4(self.all_annotations, self.class_mapping, self.image_paths, self.slices, self.image_slices, file_name) message = "Annotations have been exported successfully in YOLO (v4 and earlier) format.\n" message += f"Labels: {labels_dir}\nYAML: {yaml_path}" elif export_format == "YOLO (v5+)": output_dir, yaml_path = export_yolo_v5plus(self.all_annotations, self.class_mapping, self.image_paths, self.slices, self.image_slices, file_name) message = "Annotations have been exported successfully in YOLO (v5+) format.\n" message += f"Output directory: 
{output_dir}\nYAML: {yaml_path}" elif export_format == "Labeled Images": labeled_images_dir = export_labeled_images(self.all_annotations, self.class_mapping, self.image_paths, self.slices, self.image_slices, file_name) message = f"Labeled images have been exported successfully.\nLabeled Images: {labeled_images_dir}\n" message += f"A class summary has been saved in: {os.path.join(labeled_images_dir, 'class_summary.txt')}" elif export_format == "Semantic Labels": semantic_labels_dir = export_semantic_labels(self.all_annotations, self.class_mapping, self.image_paths, self.slices, self.image_slices, file_name) message = f"Semantic labels have been exported successfully.\nSemantic Labels: {semantic_labels_dir}\n" message += f"A class-pixel mapping has been saved in: {os.path.join(semantic_labels_dir, 'class_pixel_mapping.txt')}" elif export_format == "Pascal VOC (BBox)": voc_dir = export_pascal_voc_bbox(self.all_annotations, self.class_mapping, self.image_paths, self.slices, self.image_slices, file_name) message = "Annotations have been exported successfully in Pascal VOC format (BBox only).\n" message += f"Pascal VOC Annotations: {voc_dir}" elif export_format == "Pascal VOC (BBox + Segmentation)": voc_dir = export_pascal_voc_both(self.all_annotations, self.class_mapping, self.image_paths, self.slices, self.image_slices, file_name) message = "Annotations have been exported successfully in Pascal VOC format (BBox + Segmentation).\n" message += f"Pascal VOC Annotations: {voc_dir}" QMessageBox.information(self, "Export Complete", message) def save_slices(self, directory): slices_saved = False for image_file, image_slices in self.image_slices.items(): for slice_name, qimage in image_slices: if slice_name in self.all_annotations and self.all_annotations[slice_name]: file_path = os.path.join(directory, f"{slice_name}.png") qimage.save(file_path, "PNG") slices_saved = True return slices_saved def create_coco_annotation(self, ann, image_id, annotation_id): coco_ann = { "id": 
annotation_id, "image_id": image_id, "category_id": ann["category_id"], "area": calculate_area(ann), "iscrowd": 0 } if "segmentation" in ann: coco_ann["segmentation"] = [ann["segmentation"]] coco_ann["bbox"] = calculate_bbox(ann["segmentation"]) elif "bbox" in ann: coco_ann["bbox"] = ann["bbox"] return coco_ann def update_all_annotation_lists(self): for image_name in self.all_annotations.keys(): self.update_annotation_list(image_name) self.update_annotation_list() # Update for the current image/slice def update_annotation_list(self, image_name=None): self.annotation_list.clear() current_name = image_name or self.current_slice or self.image_file_name annotations = self.all_annotations.get(current_name, {}) for class_name, class_annotations in annotations.items(): if not class_name.startswith("Temp-"): # Only show non-temporary annotations color = self.image_label.class_colors.get(class_name, QColor(Qt.white)) for annotation in class_annotations: number = annotation.get('number', 0) area = calculate_area(annotation) item_text = f"{class_name} - {number:<3} Area: {area:.2f}" item = QListWidgetItem(item_text) item.setData(Qt.UserRole, annotation) item.setForeground(color) self.annotation_list.addItem(item) # Force the annotation list to repaint self.annotation_list.repaint() def update_slice_list_colors(self): # Set the background color of the entire list widget if self.dark_mode: self.slice_list.setStyleSheet("QListWidget { background-color: rgb(40, 40, 40); }") else: self.slice_list.setStyleSheet("QListWidget { background-color: rgb(240, 240, 240); }") for i in range(self.slice_list.count()): item = self.slice_list.item(i) slice_name = item.text() if self.dark_mode: # Dark mode if slice_name in self.all_annotations and any(self.all_annotations[slice_name].values()): item.setForeground(QColor(60, 60, 60)) # Dark gray text item.setBackground(QColor(173, 216, 230)) # Light blue background else: item.setForeground(QColor(200, 200, 200)) # Light gray text 
item.setBackground(QColor(40, 40, 40)) # Very dark gray background else: # Light mode if slice_name in self.all_annotations and any(self.all_annotations[slice_name].values()): item.setForeground(QColor(255, 255, 255)) # White text item.setBackground(QColor(70, 130, 180)) # Medium-dark blue background else: item.setForeground(QColor(0, 0, 0)) # Black text item.setBackground(QColor(240, 240, 240)) # Very light gray background # Force the list to repaint self.slice_list.repaint() def update_annotation_list_colors(self, class_name=None, color=None): for i in range(self.annotation_list.count()): item = self.annotation_list.item(i) annotation = item.data(Qt.UserRole) # Update only the item for the specific class if class_name is provided if class_name is None or annotation['category_name'] == class_name: item_color = color if class_name else self.image_label.class_colors.get(annotation['category_name'], QColor(Qt.white)) item.setForeground(item_color) def load_image_annotations(self): #print(f"Loading annotations for: {self.current_slice or self.image_file_name}") self.image_label.annotations.clear() current_name = self.current_slice or self.image_file_name #print(f"Current name for annotations: {current_name}") #print(f"All annotations keys: {list(self.all_annotations.keys())}") if current_name in self.all_annotations: self.image_label.annotations = copy.deepcopy(self.all_annotations[current_name]) #print(f"Loaded annotations: {self.image_label.annotations}") else: print(f"No annotations found for {current_name}") self.image_label.update() def save_current_annotations(self): if self.current_slice: current_name = self.current_slice elif self.image_file_name: current_name = self.image_file_name else: #print("Error: No current slice or image file name set") return #print(f"Saving annotations for: {current_name}") if self.image_label.annotations: self.all_annotations[current_name] = self.image_label.annotations.copy() #print(f"Saved {len(self.image_label.annotations)} 
annotations for {current_name}") elif current_name in self.all_annotations: del self.all_annotations[current_name] #print(f"Removed annotations for {current_name}") self.update_slice_list_colors() #print(f"All annotations now: {self.all_annotations.keys()}") #print(f"Current slice: {self.current_slice}") #print(f"Current image_file_name: {self.image_file_name}") def setup_class_list(self): """Set up the class list widget.""" self.class_list = QListWidget() self.class_list.setContextMenuPolicy(Qt.CustomContextMenu) self.class_list.customContextMenuRequested.connect(self.show_class_context_menu) self.class_list.itemClicked.connect(self.on_class_selected) self.sidebar_layout.addWidget(QLabel("Classes:")) self.sidebar_layout.addWidget(self.class_list) def setup_tool_buttons(self): """Set up the tool buttons with grouped manual and automated tools.""" self.tool_group = QButtonGroup(self) self.tool_group.setExclusive(False) # Create a widget for manual tools manual_tools_widget = QWidget() manual_layout = QVBoxLayout(manual_tools_widget) manual_layout.setSpacing(5) manual_label = QLabel("Manual Tools") manual_label.setAlignment(Qt.AlignCenter) manual_layout.addWidget(manual_label) manual_buttons_layout = QHBoxLayout() self.polygon_button = QPushButton("Polygon") self.polygon_button.setCheckable(True) self.rectangle_button = QPushButton("Rectangle") self.rectangle_button.setCheckable(True) manual_buttons_layout.addWidget(self.polygon_button) manual_buttons_layout.addWidget(self.rectangle_button) manual_layout.addLayout(manual_buttons_layout) self.tool_group.addButton(self.polygon_button) self.tool_group.addButton(self.rectangle_button) self.polygon_button.clicked.connect(self.toggle_tool) self.rectangle_button.clicked.connect(self.toggle_tool) # Create a widget for automated tools automated_tools_widget = QWidget() automated_layout = QVBoxLayout(automated_tools_widget) automated_layout.setSpacing(5) automated_label = QLabel("Automated Tools") 
automated_label.setAlignment(Qt.AlignCenter) automated_layout.addWidget(automated_label) automated_buttons_layout = QHBoxLayout() self.sam_magic_wand_button = QPushButton("Magic Wand") self.sam_magic_wand_button.setCheckable(True) automated_buttons_layout.addWidget(self.sam_magic_wand_button) automated_layout.addLayout(automated_buttons_layout) self.tool_group.addButton(self.sam_magic_wand_button) self.sam_magic_wand_button.clicked.connect(self.toggle_tool) # Add the grouped tools to the sidebar layout self.sidebar_layout.addWidget(manual_tools_widget) self.sidebar_layout.addWidget(automated_tools_widget) # Set a fixed size for all buttons to make them smaller for button in [self.polygon_button, self.rectangle_button, self.load_sam2_button, self.sam_magic_wand_button]: button.setFixedSize(100, 30) def setup_annotation_list(self): """Set up the annotation list widget.""" self.annotation_list = QListWidget() self.annotation_list.setSelectionMode(QAbstractItemView.ExtendedSelection) self.annotation_list.itemSelectionChanged.connect(self.update_highlighted_annotations) def create_menu_bar(self): menu_bar = self.menuBar() # Project Menu project_menu = menu_bar.addMenu("&Project") new_project_action = QAction("&New Project", self) new_project_action.setShortcut(QKeySequence.New) new_project_action.triggered.connect(self.new_project) project_menu.addAction(new_project_action) open_project_action = QAction("&Open Project", self) open_project_action.setShortcut(QKeySequence.Open) open_project_action.triggered.connect(self.open_project) project_menu.addAction(open_project_action) save_project_action = QAction("&Save Project", self) save_project_action.setShortcut(QKeySequence.Save) save_project_action.triggered.connect(self.save_project) project_menu.addAction(save_project_action) save_project_as_action = QAction("Save Project &As...", self) save_project_as_action.setShortcut(QKeySequence("Ctrl+Shift+S")) save_project_as_action.triggered.connect(self.save_project_as) 
project_menu.addAction(save_project_as_action) close_project_action = QAction("&Close Project", self) close_project_action.setShortcut(QKeySequence("Ctrl+W")) close_project_action.triggered.connect(self.close_project) project_menu.addAction(close_project_action) project_details_action = QAction("Project &Details", self) project_details_action.setShortcut(QKeySequence("Ctrl+I")) project_details_action.triggered.connect(self.show_project_details) project_menu.addAction(project_details_action) search_projects_action = QAction("&Search Projects", self) search_projects_action.setShortcut(QKeySequence("Ctrl+F")) search_projects_action.triggered.connect(self.show_project_search) project_menu.addAction(search_projects_action) # Settings Menu settings_menu = menu_bar.addMenu("&Settings") font_size_menu = settings_menu.addMenu("&Font Size") for size in ["Small", "Medium", "Large", "XL", "XXL"]: action = QAction(size, self) action.triggered.connect(lambda checked, s=size: self.change_font_size(s)) font_size_menu.addAction(action) toggle_dark_mode_action = QAction("Toggle &Dark Mode", self) toggle_dark_mode_action.setShortcut(QKeySequence("Ctrl+D")) toggle_dark_mode_action.triggered.connect(self.toggle_dark_mode) settings_menu.addAction(toggle_dark_mode_action) # Tools Menu tools_menu = menu_bar.addMenu("&Tools") annotation_stats_action = QAction("Annotation Statistics", self) annotation_stats_action.triggered.connect(self.show_annotation_statistics) annotation_stats_action.setShortcut(QKeySequence("Ctrl+Alt+S")) tools_menu.addAction(annotation_stats_action) coco_json_combiner_action = QAction("COCO JSON Combiner", self) coco_json_combiner_action.triggered.connect(self.show_coco_json_combiner) tools_menu.addAction(coco_json_combiner_action) dataset_splitter_action = QAction("Dataset Splitter", self) dataset_splitter_action.triggered.connect(self.open_dataset_splitter) tools_menu.addAction(dataset_splitter_action) stack_to_slices_action = QAction("Stack to Slices", self) 
stack_to_slices_action.triggered.connect(self.show_stack_to_slices) tools_menu.addAction(stack_to_slices_action) image_patcher_action = QAction("Image Patcher", self) image_patcher_action.triggered.connect(self.show_image_patcher) tools_menu.addAction(image_patcher_action) image_augmenter_action = QAction("Image Augmenter", self) image_augmenter_action.triggered.connect(self.show_image_augmenter) tools_menu.addAction(image_augmenter_action) slice_registration_action = QAction("Slice Registration", self) slice_registration_action.triggered.connect(self.show_slice_registration) tools_menu.addAction(slice_registration_action) stack_interpolator_action = QAction("Stack Interpolator", self) stack_interpolator_action.triggered.connect(self.show_stack_interpolator) tools_menu.addAction(stack_interpolator_action) dicom_converter_action = QAction("DICOM Converter", self) dicom_converter_action.triggered.connect(self.show_dicom_converter) tools_menu.addAction(dicom_converter_action) # Help Menu help_menu = menu_bar.addMenu("&Help") help_action = QAction("&Show Help", self) help_action.setShortcut(QKeySequence.HelpContents) help_action.triggered.connect(self.show_help) help_menu.addAction(help_action) def change_font_size(self, size): self.current_font_size = size self.apply_theme_and_font() def setup_sidebar(self): self.sidebar = QWidget() self.sidebar_layout = QVBoxLayout(self.sidebar) self.layout.addWidget(self.sidebar, 1) # Helper function to create section headers def create_section_header(text): label = QLabel(text) label.setProperty("class", "section-header") label.setAlignment(Qt.AlignLeft) return label # Import functionality self.import_button = QPushButton("Import Annotations with Images") self.import_button.clicked.connect(self.import_annotations) self.sidebar_layout.addWidget(self.import_button) self.import_format_selector = QComboBox() self.import_format_selector.addItem("COCO JSON") self.import_format_selector.addItem("YOLO (v4 and earlier)") # Modified name 
self.import_format_selector.addItem("YOLO (v5+)") # New format self.sidebar_layout.addWidget(self.import_format_selector) # Add spacing self.sidebar_layout.addSpacing(20) self.add_images_button = QPushButton("Add New Images") self.add_images_button.clicked.connect(self.add_images) self.sidebar_layout.addWidget(self.add_images_button) self.add_class_button = QPushButton("Add Classes") self.add_class_button.clicked.connect(lambda: self.add_class()) self.sidebar_layout.addWidget(self.add_class_button) # Class list (without the "Classes" header) self.class_list = QListWidget() self.class_list.setContextMenuPolicy(Qt.CustomContextMenu) self.class_list.customContextMenuRequested.connect(self.show_class_context_menu) self.class_list.itemClicked.connect(self.on_class_selected) self.sidebar_layout.addWidget(self.class_list) button_layout_class_list = QHBoxLayout() self.clrButton =QPushButton(self.class_list) self.clrButton.setText("clear all") self.clrButton.setEnabled(False) self.allButton = QPushButton(self.class_list) self.allButton.setText("select all") self.allButton.setEnabled(False) button_layout_class_list.addWidget(self.clrButton) button_layout_class_list.addWidget(self.allButton) self.clrButton.clicked.connect(lambda : self.toggle_all_class(Qt.Unchecked)) self.allButton.clicked.connect(lambda : self.toggle_all_class(Qt.Checked)) self.sidebar_layout.addLayout(button_layout_class_list) # Annotation section self.sidebar_layout.addWidget(create_section_header("Annotation")) annotation_widget = QWidget() annotation_layout = QVBoxLayout(annotation_widget) # Manual tools subsection manual_widget = QWidget() manual_layout = QVBoxLayout(manual_widget) button_layout_top = QHBoxLayout() self.polygon_button = QPushButton("Polygon") self.polygon_button.setCheckable(True) self.rectangle_button = QPushButton("Rectangle") self.rectangle_button.setCheckable(True) button_layout_top.addWidget(self.polygon_button) button_layout_top.addWidget(self.rectangle_button) 
button_layout_bottom = QHBoxLayout() self.paint_brush_button = QPushButton("Paint Brush") self.paint_brush_button.setCheckable(True) self.eraser_button = QPushButton("Eraser") self.eraser_button.setCheckable(True) button_layout_bottom.addWidget(self.paint_brush_button) button_layout_bottom.addWidget(self.eraser_button) manual_layout.addLayout(button_layout_top) manual_layout.addLayout(button_layout_bottom) annotation_layout.addWidget(manual_widget) # SAM-Assisted tools subsection sam_widget = QWidget() sam_layout = QVBoxLayout(sam_widget) # SAM-Assisted button on top self.sam_magic_wand_button = QPushButton("SAM-Assisted") self.sam_magic_wand_button.setCheckable(True) self.sam_magic_wand_button.clicked.connect(self.toggle_sam_assisted) sam_layout.addWidget(self.sam_magic_wand_button) # Add SAM model selector self.sam_model_selector = QComboBox() self.sam_model_selector.addItem("Pick a SAM Model") self.sam_model_selector.addItems(list(self.sam_utils.sam_models.keys())) self.sam_model_selector.currentTextChanged.connect(self.change_sam_model) sam_layout.addWidget(self.sam_model_selector) annotation_layout.addWidget(sam_widget) # Setup tool group self.tool_group = QButtonGroup(self) self.tool_group.setExclusive(False) self.tool_group.addButton(self.polygon_button) self.tool_group.addButton(self.rectangle_button) self.tool_group.addButton(self.paint_brush_button) self.tool_group.addButton(self.eraser_button) self.tool_group.addButton(self.sam_magic_wand_button) self.polygon_button.clicked.connect(self.toggle_tool) self.rectangle_button.clicked.connect(self.toggle_tool) self.paint_brush_button.clicked.connect(self.toggle_tool) self.eraser_button.clicked.connect(self.toggle_tool) self.sam_magic_wand_button.clicked.connect(self.toggle_tool) # Annotations list subsection annotation_layout.addWidget(QLabel("Annotations")) self.annotation_list = QListWidget() self.annotation_list.setSelectionMode(QAbstractItemView.ExtendedSelection) 
self.annotation_list.itemSelectionChanged.connect(self.update_highlighted_annotations) annotation_layout.addWidget(self.annotation_list) # Create a horizontal layout for the sort buttons sort_button_layout = QHBoxLayout() self.sort_by_class_button = QPushButton("Sort by Class") self.sort_by_class_button.clicked.connect(self.sort_annotations_by_class) sort_button_layout.addWidget(self.sort_by_class_button) self.sort_by_area_button = QPushButton("Sort by Area") self.sort_by_area_button.clicked.connect(self.sort_annotations_by_area) sort_button_layout.addWidget(self.sort_by_area_button) # Add the sort button layout to the annotation layout annotation_layout.addLayout(sort_button_layout) # Delete and Merge annotation buttons self.delete_button = QPushButton("Delete") self.delete_button.clicked.connect(self.delete_selected_annotations) self.merge_button = QPushButton("Merge") self.merge_button.clicked.connect(self.merge_annotations) self.change_class_button = QPushButton("Change Class") self.change_class_button.clicked.connect(self.change_annotation_class) # Create a horizontal layout for the other buttons button_layout = QHBoxLayout() button_layout.addWidget(self.delete_button) button_layout.addWidget(self.merge_button) button_layout.addWidget(self.change_class_button) # Add the button layout to the annotation layout annotation_layout.addLayout(button_layout) # Add export format selector self.export_format_selector = QComboBox() self.export_format_selector.addItem("COCO JSON") self.export_format_selector.addItem("YOLO (v4 and earlier)") # Modified name self.export_format_selector.addItem("YOLO (v5+)") # New format self.export_format_selector.addItem("Labeled Images") self.export_format_selector.addItem("Semantic Labels") self.export_format_selector.addItem("Pascal VOC (BBox)") self.export_format_selector.addItem("Pascal VOC (BBox + Segmentation)") annotation_layout.addWidget(QLabel("Export Format:")) annotation_layout.addWidget(self.export_format_selector) 
self.export_button = QPushButton("Export Annotations") self.export_button.clicked.connect(self.export_annotations) annotation_layout.addWidget(self.export_button) # Add the annotation widget to the sidebar self.sidebar_layout.addWidget(annotation_widget) def sort_annotations_by_class(self): current_name = self.current_slice or self.image_file_name if current_name not in self.all_annotations: QMessageBox.information(self, "No Annotations", "There are no annotations to sort for this image.") return annotations = self.all_annotations[current_name] sorted_annotations = [] for class_name in sorted(annotations.keys()): if not class_name.startswith("Temp-"): # Skip temporary classes class_annotations = sorted(annotations[class_name], key=lambda x: x.get('number', 0)) sorted_annotations.extend(class_annotations) self.update_annotation_list_with_sorted(sorted_annotations) def sort_annotations_by_area(self): current_name = self.current_slice or self.image_file_name if current_name not in self.all_annotations: QMessageBox.information(self, "No Annotations", "There are no annotations to sort for this image.") return annotations = self.all_annotations[current_name] sorted_annotations = [] for class_name in annotations.keys(): if not class_name.startswith("Temp-"): # Skip temporary classes class_annotations = sorted(annotations[class_name], key=lambda x: calculate_area(x), reverse=True) sorted_annotations.extend(class_annotations) self.update_annotation_list_with_sorted(sorted_annotations) def update_annotation_list_with_sorted(self, sorted_annotations): self.annotation_list.clear() for annotation in sorted_annotations: class_name = annotation['category_name'] if not class_name.startswith("Temp-"): # Only add non-temporary annotations number = annotation.get('number', 0) area = calculate_area(annotation) item_text = f"{class_name} - {number:<3} Area: {area:.2f}" item = QListWidgetItem(item_text) item.setData(Qt.UserRole, annotation) color = 
self.image_label.class_colors.get(class_name, QColor(Qt.white)) item.setForeground(color) self.annotation_list.addItem(item) self.image_label.update() def change_sam_model(self, model_name): self.sam_utils.change_sam_model(model_name) self.current_sam_model = self.sam_utils.current_sam_model if model_name != "Pick a SAM Model": # Enable the SAM Magic Wand button self.sam_magic_wand_button.setEnabled(True) # Activate the SAM Magic Wand tool self.sam_magic_wand_button.setChecked(True) self.activate_sam_magic_wand() print(f"Changed SAM model to: {model_name}") else: # Disable and deactivate the SAM Magic Wand button self.sam_magic_wand_button.setEnabled(False) self.sam_magic_wand_button.setChecked(False) self.deactivate_sam_magic_wand() print("SAM model unset") def setup_font_size_selector(self): font_size_label = QLabel("Font Size:") self.font_size_selector = QComboBox() self.font_size_selector.addItems(["Small", "Medium", "Large"]) self.font_size_selector.setCurrentText("Medium") self.font_size_selector.currentTextChanged.connect(self.on_font_size_changed) self.sidebar_layout.addWidget(font_size_label) self.sidebar_layout.addWidget(self.font_size_selector) def on_font_size_changed(self, size): self.current_font_size = size self.apply_theme_and_font() def apply_theme_and_font(self): font_size = self.font_sizes[self.current_font_size] if self.dark_mode: style = soft_dark_stylesheet else: style = default_stylesheet # Combine the theme stylesheet with font size combined_style = f"{style}\nQWidget {{ font-size: {font_size}pt; }}" self.setStyleSheet(combined_style) # Apply font size to all widgets for widget in self.findChildren(QWidget): font = widget.font() font.setPointSize(font_size) widget.setFont(font) self.image_label.setFont(QFont("Arial", font_size)) self.update() def toggle_dark_mode(self): self.dark_mode = not self.dark_mode self.apply_theme_and_font() # Update slice list colors self.update_slice_list_colors() # Update other UI elements if necessary 
self.update_class_list() self.update_annotation_list() # Force a repaint of the main window self.repaint() def apply_stylesheet(self): if self.dark_mode: self.setStyleSheet(soft_dark_stylesheet) else: self.setStyleSheet(default_stylesheet) def update_ui_colors(self): # Update colors for elements that need to retain their functionality self.update_annotation_list_colors() self.update_slice_list_colors() self.image_label.update() def setup_image_area(self): """Set up the main image area.""" self.image_widget = QWidget() self.image_layout = QVBoxLayout(self.image_widget) self.layout.addWidget(self.image_widget, 3) self.scroll_area = QScrollArea() self.scroll_area.setWidgetResizable(True) self.scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded) self.scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) # Use the already initialized image_label self.image_label.setAlignment(Qt.AlignCenter) self.scroll_area.setWidget(self.image_label) self.image_layout.addWidget(self.scroll_area) self.zoom_slider = QSlider(Qt.Horizontal) self.zoom_slider.setMinimum(10) self.zoom_slider.setMaximum(500) self.zoom_slider.setValue(100) self.zoom_slider.setTickPosition(QSlider.TicksBelow) self.zoom_slider.setTickInterval(50) self.zoom_slider.valueChanged.connect(self.zoom_image) self.image_layout.addWidget(self.zoom_slider) self.image_info_label = QLabel() self.image_layout.addWidget(self.image_info_label) def setup_image_list(self): """Set up the image list area.""" self.image_list_widget = QWidget() self.image_list_layout = QVBoxLayout(self.image_list_widget) self.layout.addWidget(self.image_list_widget, 1) self.image_list_label = QLabel("Images:") self.image_list_layout.addWidget(self.image_list_label) self.image_list = QListWidget() self.image_list.itemClicked.connect(self.switch_image) self.image_list.currentRowChanged.connect(lambda row: self.switch_image(self.image_list.currentItem())) self.image_list.setContextMenuPolicy(Qt.CustomContextMenu) 
self.image_list.customContextMenuRequested.connect(self.show_image_context_menu) self.image_list_layout.addWidget(self.image_list) self.clear_all_button = QPushButton("Clear All Images and Annotations") self.clear_all_button.clicked.connect(self.clear_all) self.image_list_layout.addWidget(self.clear_all_button) ########## ### Tools ########## I love useful image processing tools :) def open_dataset_splitter(self): self.dataset_splitter = DatasetSplitterTool(self) self.dataset_splitter.setWindowModality(Qt.ApplicationModal) self.dataset_splitter.show_centered(self) def show_annotation_statistics(self): if not self.all_annotations: QMessageBox.warning(self, "No Annotations", "There are no annotations to analyze.") return try: self.annotation_stats_dialog = show_annotation_statistics(self, self.all_annotations) except Exception as e: QMessageBox.critical(self, "Error", f"An error occurred while showing annotation statistics: {str(e)}") def show_coco_json_combiner(self): self.coco_json_combiner_dialog = show_coco_json_combiner(self) def show_stack_to_slices(self): self.stack_to_slices_dialog = show_stack_to_slices(self) def show_image_patcher(self): self.image_patcher_dialog = show_image_patcher(self) def show_image_augmenter(self): self.image_augmenter_dialog = show_image_augmenter(self) def show_slice_registration(self): self.slice_registration_dialog = SliceRegistrationTool(self) self.slice_registration_dialog.show_centered(self) def show_stack_interpolator(self): self.stack_interpolator_dialog = StackInterpolator(self) self.stack_interpolator_dialog.show_centered(self) def show_dicom_converter(self): self.dicom_converter_dialog = DicomConverter(self) self.dicom_converter_dialog.show_centered(self) ################################################################### # update the show_help method: def show_help(self): self.help_window = HelpWindow(dark_mode=self.dark_mode, font_size=self.font_sizes[self.current_font_size]) self.help_window.show_centered(self) def 
add_images(self): if not self.image_label.check_unsaved_changes(): return file_names, _ = QFileDialog.getOpenFileNames(self, "Add Images", "", "Image Files (*.png *.jpg *.bmp *.tif *.tiff *.czi)") if file_names: self.add_images_to_list(file_names) def clear_all(self, new_project=False, show_messages=True): if not new_project and show_messages: reply = self.show_question('Clear All', "Are you sure you want to clear all images and annotations? This action cannot be undone.") if reply != QMessageBox.Yes: return # Clear images self.image_list.clear() self.image_paths.clear() self.all_images.clear() self.current_image = None self.image_file_name = "" # Clear the image display self.image_label.clear() self.image_label.setPixmap(QPixmap()) # Set an empty pixmap self.image_label.original_pixmap = None self.image_label.scaled_pixmap = None # Clear annotations self.all_annotations.clear() self.annotation_list.clear() self.image_label.annotations.clear() self.image_label.highlighted_annotations.clear() # Clear current class self.current_class = None # Reset class-related data self.class_list.clear() self.allButton.setEnabled(False) self.clrButton.setEnabled(False) self.image_label.class_colors.clear() self.class_mapping.clear() # Clear slices self.image_slices.clear() self.slices = [] self.slice_list.clear() self.current_slice = None self.current_stack = None # Reset zoom self.image_label.zoom_factor = 1.0 self.zoom_slider.setValue(100) # Reset tools self.image_label.current_tool = None self.polygon_button.setChecked(False) self.rectangle_button.setChecked(False) self.sam_magic_wand_button.setChecked(False) self.sam_magic_wand_button.setEnabled(False) # Disable the SAM-Assisted button self.image_label.sam_magic_wand_active = False # Deactivate SAM magic wand # Reset SAM-related attributes self.image_label.sam_bbox = None self.image_label.drawing_sam_bbox = False self.image_label.temp_sam_prediction = None self.image_label.setCursor(Qt.ArrowCursor) # Reset cursor to default 
self.sam_model_selector.setCurrentIndex(0) # Reset to "Pick a SAM Model" self.current_sam_model = None # Reset the current SAM model # Reset project-related attributes if not new_project: if hasattr(self, 'current_project_file'): del self.current_project_file if hasattr(self, 'current_project_dir'): del self.current_project_dir # Update UI self.image_label.update() self.update_image_info() # Force a repaint of the main window self.repaint() self.update_window_title() def show_warning(self, title, message): QMessageBox.warning(self, title, message) def show_info(self, title, message): QMessageBox.information(self, title, message) def update_image_info(self, additional_info=None): if self.current_image: width = self.current_image.width() height = self.current_image.height() info = f"Image: {width}x{height}" if additional_info: info += f", {additional_info}" self.image_info_label.setText(info) else: self.image_info_label.setText("No image loaded") def show_question(self, title, message): return QMessageBox.question(self, title, message, QMessageBox.Yes | QMessageBox.No, QMessageBox.No) def show_image_context_menu(self, position): menu = QMenu() current_item = self.image_list.itemAt(position) if current_item: file_name = current_item.text() delete_action = menu.addAction("Remove Image") if not self.is_multi_dimensional(file_name): predict_action = menu.addAction("Predict using YOLO") if self.is_multi_dimensional(file_name): redefine_dimensions_action = menu.addAction("Redefine Dimensions") action = menu.exec_(self.image_list.mapToGlobal(position)) if action == delete_action: self.remove_image() elif not self.is_multi_dimensional(file_name) and action == predict_action: self.predict_single_image(file_name) elif self.is_multi_dimensional(file_name) and action == redefine_dimensions_action: self.redefine_dimensions(file_name) def is_multi_dimensional(self, file_name): return file_name.lower().endswith(('.tif', '.tiff', '.czi')) def predict_single_image(self, file_name): 
if self.is_multi_dimensional(file_name): return # Do nothing for multi-dimensional images if not self.yolo_trainer or not self.yolo_trainer.model: QMessageBox.warning(self, "No Model", "Please load a YOLO model first from the YOLO > Prediction Settings > Load Model menu.") return # Deactivate SAM tool before prediction self.deactivate_sam_magic_wand() image_path = self.image_paths[file_name] try: results = self.yolo_trainer.predict(image_path) self.process_yolo_results(results, file_name) except Exception as e: QMessageBox.warning(self, "Prediction Error", f"An error occurred during prediction: {str(e)}\n\n" "This might be due to a mismatch between the model and the YAML file classes. " "Please check that the YAML file corresponds to the loaded model.") def redefine_dimensions(self, file_name): file_path = self.image_paths.get(file_name) if not file_path or not file_path.lower().endswith(('.tif', '.tiff', '.czi')): return # Exit the method if it's not a TIFF or CZI file reply = QMessageBox.warning(self, "Redefine Dimensions", "Redefining dimensions will cause all associated annotations to be lost. 
" "Do you want to continue?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No) if reply == QMessageBox.Yes: # Remove existing annotations for this file base_name = os.path.splitext(file_name)[0] print(f"Removing annotations for image: {base_name}") #print(f"Current annotations: {list(self.all_annotations.keys())}") # Create a list of keys to remove, using a more specific matching condition keys_to_remove = [key for key in self.all_annotations.keys() if key == base_name or (key.startswith(f"{base_name}_") and not key.startswith(f"{base_name}_8bit"))] print(f"Keys to remove: {keys_to_remove}") # Remove the annotations for key in keys_to_remove: del self.all_annotations[key] #print(f"Annotations after removal: {list(self.all_annotations.keys())}") # Remove existing slices if base_name in self.image_slices: del self.image_slices[base_name] # Clear current image if it's the one being redefined if self.image_file_name == file_name: self.current_image = None self.image_label.clear() # Reload the image with new dimension dialog if file_path.lower().endswith(('.tif', '.tiff')): self.load_tiff(file_path, force_dimension_dialog=True) elif file_path.lower().endswith('.czi'): self.load_czi(file_path, force_dimension_dialog=True) # Update UI self.update_slice_list() self.update_annotation_list() self.image_label.update() #print(f"Final annotations: {list(self.all_annotations.keys())}") QMessageBox.information(self, "Dimensions Redefined", "The dimensions have been redefined and the image reloaded. 
" "All previous annotations for this image have been removed.") def remove_image(self): current_item = self.image_list.currentItem() if current_item: file_name = current_item.text() # Remove from all data structures self.image_list.takeItem(self.image_list.row(current_item)) self.image_paths.pop(file_name, None) self.all_images = [img for img in self.all_images if img["file_name"] != file_name] # Remove annotations self.all_annotations.pop(file_name, None) # Handle multi-dimensional images base_name = os.path.splitext(file_name)[0] if base_name in self.image_slices: # Remove slices for slice_name, _ in self.image_slices[base_name]: self.all_annotations.pop(slice_name, None) del self.image_slices[base_name] # Clear slice list self.slice_list.clear() # Clear current image and slice if it was the removed image if self.image_file_name == file_name: self.current_image = None self.image_file_name = "" self.current_slice = None self.image_label.clear() self.annotation_list.clear() # Switch to another image if available if self.image_list.count() > 0: next_item = self.image_list.item(0) self.image_list.setCurrentItem(next_item) self.switch_image(next_item) else: # No images left self.current_image = None self.image_file_name = "" self.current_slice = None self.image_label.clear() self.annotation_list.clear() self.slice_list.clear() # Update UI self.update_ui() self.auto_save() # Auto-save after removing an image def load_annotations(self): file_name, _ = QFileDialog.getOpenFileName(self, "Load Annotations", "", "JSON Files (*.json)") if file_name: with open(file_name, 'r') as f: self.loaded_json = json.load(f) # Load categories self.class_list.clear() self.image_label.class_colors.clear() self.class_mapping.clear() for category in self.loaded_json["categories"]: class_name = category["name"] self.class_mapping[class_name] = category["id"] # Assign a color if not already assigned if class_name not in self.image_label.class_colors: color = 
QColor(Qt.GlobalColor(len(self.image_label.class_colors) % 16 + 7)) self.image_label.class_colors[class_name] = color # Add item to class list with color indicator item = QListWidgetItem(class_name) self.update_class_item_color(item, self.image_label.class_colors[class_name]) self.class_list.addItem(item) # Create a mapping of image IDs to file names image_id_to_filename = {img["id"]: img["file_name"] for img in self.loaded_json["images"]} # Load image information json_images = {img["file_name"]: img for img in self.loaded_json["images"]} # Update existing images and add new ones from JSON updated_all_images = [] for i in range(self.image_list.count()): item = self.image_list.item(i) file_name = item.text() if file_name in json_images: updated_image = self.all_images[i].copy() updated_image.update(json_images[file_name]) updated_all_images.append(updated_image) del json_images[file_name] else: updated_all_images.append(self.all_images[i]) # Add remaining images from JSON for img in json_images.values(): updated_all_images.append(img) self.image_list.addItem(img["file_name"]) self.all_images = updated_all_images # Load annotations self.all_annotations.clear() for annotation in self.loaded_json["annotations"]: image_id = annotation["image_id"] file_name = image_id_to_filename.get(image_id) if file_name: if file_name not in self.all_annotations: self.all_annotations[file_name] = {} category = next((cat for cat in self.loaded_json["categories"] if cat["id"] == annotation["category_id"]), None) if category: category_name = category["name"] if category_name not in self.all_annotations[file_name]: self.all_annotations[file_name][category_name] = [] ann = { "category_id": annotation["category_id"], "category_name": category_name, } if "segmentation" in annotation: ann["segmentation"] = annotation["segmentation"][0] ann["type"] = "polygon" elif "bbox" in annotation: ann["bbox"] = annotation["bbox"] ann["type"] = "bbox" # Add number field if it's missing if "number" not in 
ann: ann["number"] = len(self.all_annotations[file_name][category_name]) + 1 self.all_annotations[file_name][category_name].append(ann) # Check for missing images missing_images = [img["file_name"] for img in self.loaded_json["images"] if img["file_name"] not in self.image_paths] if missing_images: self.show_warning("Missing Images", "The following images are missing:\n" + "\n".join(missing_images)) # Reload the current image if it exists, otherwise load the first image if self.image_file_name and self.image_file_name in self.all_annotations: self.switch_image(self.image_list.findItems(self.image_file_name, Qt.MatchExactly)[0]) elif self.all_images: self.switch_image(self.image_list.item(0)) self.image_label.highlighted_annotations = [] # Clear existing highlights self.update_annotation_list() # This will repopulate the annotation list self.image_label.update() # Force a redraw of the image label if self.class_list.count() > 0: self.allButton.setEnabled(True) self.clrButton.setEnabled(True) def clear_highlighted_annotation(self): self.image_label.highlighted_annotation = None self.image_label.update() def update_highlighted_annotations(self): selected_items = self.annotation_list.selectedItems() self.image_label.highlighted_annotations = [item.data(Qt.UserRole) for item in selected_items] self.image_label.update() # Force a redraw of the image label # Enable/disable merge and change class buttons based on selection self.merge_button.setEnabled(len(selected_items) >= 2) self.change_class_button.setEnabled(len(selected_items) > 0) def renumber_annotations(self): current_name = self.current_slice or self.image_file_name if current_name in self.all_annotations: for class_name, annotations in self.all_annotations[current_name].items(): for i, ann in enumerate(annotations, start=1): ann['number'] = i self.update_annotation_list() def delete_selected_annotations(self): selected_items = self.annotation_list.selectedItems() if not selected_items: QMessageBox.warning(self, 
"No Selection", "Please select an annotation to delete.") return reply = QMessageBox.question(self, 'Delete Annotations', f"Are you sure you want to delete {len(selected_items)} annotation(s)?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No) if reply == QMessageBox.Yes: # Create a list of annotations to remove annotations_to_remove = [] for item in selected_items: annotation = item.data(Qt.UserRole) annotations_to_remove.append((annotation['category_name'], annotation)) # Remove annotations from image_label.annotations for category_name, annotation in annotations_to_remove: if category_name in self.image_label.annotations: if annotation in self.image_label.annotations[category_name]: self.image_label.annotations[category_name].remove(annotation) # Update all_annotations current_name = self.current_slice or self.image_file_name self.all_annotations[current_name] = self.image_label.annotations # Sort and update the annotation list based on the current sorting method if self.current_sort_method == "area": self.sort_annotations_by_area() else: self.sort_annotations_by_class() self.image_label.highlighted_annotations.clear() self.image_label.update() # Update slice list colors self.update_slice_list_colors() QMessageBox.information(self, "Annotations Deleted", f"{len(selected_items)} annotation(s) have been deleted.") self.auto_save() # Auto-save after deleting annotations def merge_annotations(self): if self.image_label.editing_polygon is not None: QMessageBox.warning(self, "Edit Mode Active", "Please exit the annotation edit mode before merging annotations.") return selected_items = self.annotation_list.selectedItems() if len(selected_items) < 2: QMessageBox.warning(self, "Not Enough Annotations", "Please select at least two annotations to merge.") return class_name = selected_items[0].data(Qt.UserRole)['category_name'] if not all(item.data(Qt.UserRole)['category_name'] == class_name for item in selected_items): QMessageBox.warning(self, "Mixed Classes", "All 
selected annotations must be from the same class.") return polygons = [] original_annotations = [] for item in selected_items: annotation = item.data(Qt.UserRole) original_annotations.append(annotation) if 'segmentation' in annotation: points = zip(annotation['segmentation'][0::2], annotation['segmentation'][1::2]) polygon = Polygon(points) if not polygon.is_valid: polygon = polygon.buffer(0) polygons.append(polygon) def are_all_polygons_connected(polygons): if len(polygons) < 2: return True connected = set([0]) # Start with the first polygon to_check = set(range(1, len(polygons))) while to_check: newly_connected = set() for i in connected: for j in to_check: if polygons[i].intersects(polygons[j]) or polygons[i].touches(polygons[j]): newly_connected.add(j) if not newly_connected: return False # If no new connections found, they're not all connected connected.update(newly_connected) to_check -= newly_connected return True # All polygons are connected if not are_all_polygons_connected(polygons): QMessageBox.warning(self, "Disconnected Polygons", "Not all selected annotations are connected. 
Please select only connected annotations to merge.") return try: merged_polygon = unary_union(polygons) except Exception as e: QMessageBox.warning(self, "Merge Error", f"Unable to merge the selected annotations due to an error: {str(e)}") return new_annotation = { "segmentation": [], "category_id": self.class_mapping[class_name], "category_name": class_name, } if isinstance(merged_polygon, Polygon): new_annotation["segmentation"] = [coord for point in merged_polygon.exterior.coords for coord in point] elif isinstance(merged_polygon, MultiPolygon): largest_polygon = max(merged_polygon.geoms, key=lambda p: p.area) new_annotation["segmentation"] = [coord for point in largest_polygon.exterior.coords for coord in point] # Ask user about keeping original annotations msg_box = QMessageBox(self) msg_box.setWindowTitle("Merge Annotations") msg_box.setText("Do you want to keep the original annotations?") msg_box.setIcon(QMessageBox.Question) keep_button = msg_box.addButton("Keep", QMessageBox.YesRole) delete_button = msg_box.addButton("Delete", QMessageBox.NoRole) cancel_button = msg_box.addButton("Cancel", QMessageBox.RejectRole) msg_box.setDefaultButton(cancel_button) msg_box.setEscapeButton(cancel_button) msg_box.exec_() if msg_box.clickedButton() == cancel_button: return if msg_box.clickedButton() == delete_button: for annotation in original_annotations: if annotation in self.image_label.annotations[class_name]: self.image_label.annotations[class_name].remove(annotation) self.image_label.annotations.setdefault(class_name, []).append(new_annotation) current_name = self.current_slice or self.image_file_name self.all_annotations[current_name] = self.image_label.annotations self.renumber_annotations() self.update_annotation_list() self.save_current_annotations() self.update_slice_list_colors() self.image_label.update() QMessageBox.information(self, "Merge Complete", "Annotations have been merged successfully.") self.auto_save() # Auto-save after merging annotations def 
delete_selected_image(self): current_item = self.image_list.currentItem() if current_item: file_name = current_item.text() reply = QMessageBox.question(self, 'Delete Image', f"Are you sure you want to delete the image '{file_name}'?\n\n" "This will remove the image and all its associated annotations.", QMessageBox.Yes | QMessageBox.No, QMessageBox.No) if reply == QMessageBox.Yes: # Remove from all data structures self.image_list.takeItem(self.image_list.row(current_item)) self.image_paths.pop(file_name, None) self.all_images = [img for img in self.all_images if img["file_name"] != file_name] # Remove annotations self.all_annotations.pop(file_name, None) # Handle multi-dimensional images base_name = os.path.splitext(file_name)[0] if base_name in self.image_slices: # Remove slices for slice_name, _ in self.image_slices[base_name]: self.all_annotations.pop(slice_name, None) del self.image_slices[base_name] # Clear slice list self.slice_list.clear() # Clear current image and slice if it was the removed image if self.image_file_name == file_name: self.current_image = None self.image_file_name = "" self.current_slice = None self.image_label.clear() self.annotation_list.clear() # Switch to another image if available if self.image_list.count() > 0: next_item = self.image_list.item(0) self.image_list.setCurrentItem(next_item) self.switch_image(next_item) else: # No images left self.current_image = None self.image_file_name = "" self.current_slice = None self.image_label.clear() self.annotation_list.clear() self.slice_list.clear() # Update UI self.update_ui() QMessageBox.information(self, "Image Deleted", f"The image '{file_name}' has been deleted.") def display_image(self): if self.current_image: if isinstance(self.current_image, QImage): pixmap = QPixmap.fromImage(self.current_image) elif isinstance(self.current_image, QPixmap): pixmap = self.current_image else: print(f"Unexpected image type: {type(self.current_image)}") return if not pixmap.isNull(): 
self.image_label.setPixmap(pixmap) self.image_label.adjustSize() else: print("Error: Null pixmap") else: self.image_label.clear() print("No current image to display") def update_ui(self): self.update_image_list() self.update_slice_list() self.update_class_list() self.update_annotation_list() self.image_label.update() self.update_image_info() def add_class(self, class_name=None, color=None): if not self.image_label.check_unsaved_changes(): return if class_name is None: while True: class_name, ok = QInputDialog.getText(self, "Add Class", "Enter class name:") if not ok: print("Class addition cancelled") return if not class_name.strip(): QMessageBox.warning(self, "Invalid Input", "Please enter a class name or press Cancel.") continue if class_name in self.class_mapping: QMessageBox.warning(self, "Duplicate Class", f"The class '{class_name}' already exists. Please choose a different name.") continue break else: # For programmatic addition (e.g., from YOLO predictions) if class_name in self.class_mapping: print(f"Class '{class_name}' already exists. Skipping addition.") return if not isinstance(class_name, str): print(f"Warning: class_name is not a string. 
Converting {class_name} to string.") class_name = str(class_name) if color is None: color = QColor(Qt.GlobalColor(len(self.image_label.class_colors) % 16 + 7)) elif isinstance(color, str): color = QColor(color) print(f"Adding class: {class_name}, color: {color.name()}") self.image_label.class_colors[class_name] = color self.class_mapping[class_name] = len(self.class_mapping) + 1 try: item = QListWidgetItem(class_name) # Create a color indicator pixmap = QPixmap(16, 16) pixmap.fill(color) item.setIcon(QIcon(pixmap)) # Set visibility state item.setData(Qt.UserRole, True) # Set checkbox item.setFlags(item.flags() | Qt.ItemIsUserCheckable) item.setCheckState(Qt.Checked) self.class_list.addItem(item) self.class_list.setCurrentItem(item) self.current_class = class_name print(f"Class added successfully: {class_name}") if not self.is_loading_project: self.auto_save() except Exception as e: print(f"Error adding class: {e}") import traceback traceback.print_exc() if self.class_list.count() > 0: self.allButton.setEnabled(True) self.clrButton.setEnabled(True) def update_class_item_color(self, item, color): pixmap = QPixmap(16, 16) pixmap.fill(color) item.setIcon(QIcon(pixmap)) def update_class_list(self): self.class_list.clear() for class_name, color in self.image_label.class_colors.items(): item = QListWidgetItem(class_name) # Create a color indicator pixmap = QPixmap(16, 16) pixmap.fill(color) item.setIcon(QIcon(pixmap)) # Store the visibility state item.setData(Qt.UserRole, self.image_label.class_visibility.get(class_name, True)) # Set checkbox item.setFlags(item.flags() | Qt.ItemIsUserCheckable) item.setCheckState(Qt.Checked if item.data(Qt.UserRole) else Qt.Unchecked) self.class_list.addItem(item) # Re-select the current class if it exists if self.current_class: items = self.class_list.findItems(self.current_class, Qt.MatchExactly) if items: self.class_list.setCurrentItem(items[0]) elif self.class_list.count() > 0: # If no class is selected, select the first one 
self.class_list.setCurrentItem(self.class_list.item(0)) if self.class_list.count() > 0: self.allButton.setEnabled(True) self.clrButton.setEnabled(True) print(f"Updated class list with {self.class_list.count()} items") def update_class_selection(self): for i in range(self.class_list.count()): item = self.class_list.item(i) if item.text() == self.current_class: item.setSelected(True) else: item.setSelected(False) def toggle_class_visibility(self, item): class_name = item.text() is_visible = item.checkState() == Qt.Checked self.image_label.set_class_visibility(class_name, is_visible) item.setData(Qt.UserRole, is_visible) self.image_label.update() def toggle_all_class(self, checked): for i in range(self.class_list.count()): item = self.class_list.item(i) item.setCheckState(checked) # Update image annotations self.image_label.update() def change_annotation_class(self): selected_items = self.annotation_list.selectedItems() if not selected_items: QMessageBox.warning(self, "No Selection", "Please select one or more annotations to change class.") return class_dialog = QDialog(self) class_dialog.setWindowTitle("Change Class") layout = QVBoxLayout(class_dialog) class_combo = QComboBox() for class_name in self.class_mapping.keys(): class_combo.addItem(class_name) layout.addWidget(class_combo) button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) button_box.accepted.connect(class_dialog.accept) button_box.rejected.connect(class_dialog.reject) layout.addWidget(button_box) if class_dialog.exec_() == QDialog.Accepted: new_class = class_combo.currentText() current_name = self.current_slice or self.image_file_name # Get the current maximum number for the new class max_number = max([ann.get('number', 0) for ann in self.image_label.annotations.get(new_class, [])] + [0]) for item in selected_items: annotation = item.data(Qt.UserRole) old_class = annotation['category_name'] # Remove from old class self.image_label.annotations[old_class].remove(annotation) if not 
self.image_label.annotations[old_class]: del self.image_label.annotations[old_class] # Add to new class with updated number annotation['category_name'] = new_class annotation['category_id'] = self.class_mapping[new_class] max_number += 1 annotation['number'] = max_number if new_class not in self.image_label.annotations: self.image_label.annotations[new_class] = [] self.image_label.annotations[new_class].append(annotation) # Update all_annotations self.all_annotations[current_name] = self.image_label.annotations # Renumber all annotations for consistency self.renumber_annotations() self.update_annotation_list() self.image_label.update() self.save_current_annotations() self.update_slice_list_colors() self.auto_save() QMessageBox.information(self, "Class Changed", f"Selected annotations have been changed to class '{new_class}'.") def toggle_tool(self): if not self.image_label.check_unsaved_changes(): return sender = self.sender() if sender is None: sender = self.sam_magic_wand_button if not self.current_class: QMessageBox.warning(self, "No Class Selected", "Please select a class before using annotation tools.") sender.setChecked(False) return if self.current_class and self.current_class.startswith("Temp-"): QMessageBox.warning(self, "Invalid Selection", "Cannot use annotation tools with temporary classes.") sender.setChecked(False) return other_buttons = [btn for btn in self.tool_group.buttons() if btn != sender] # Deactivate SAM if we're switching to a different tool if sender != self.sam_magic_wand_button and self.image_label.sam_magic_wand_active: self.deactivate_sam_magic_wand() if sender.isChecked(): # Uncheck all other buttons for btn in other_buttons: btn.setChecked(False) # Set the current tool based on the checked button if sender == self.polygon_button: self.image_label.current_tool = "polygon" elif sender == self.rectangle_button: self.image_label.current_tool = "rectangle" elif sender == self.sam_magic_wand_button: self.image_label.current_tool = 
"sam_magic_wand" self.activate_sam_magic_wand() elif sender == self.paint_brush_button: self.image_label.current_tool = "paint_brush" self.image_label.setFocus() # Set focus on the image label elif sender == self.eraser_button: self.image_label.current_tool = "eraser" self.image_label.setFocus() # Set focus on the image label else: self.image_label.current_tool = None if sender == self.sam_magic_wand_button: self.deactivate_sam_magic_wand() # Update UI based on the current tool self.update_ui_for_current_tool() def wheelEvent(self, event): if event.modifiers() == Qt.ControlModifier: delta = event.angleDelta().y() if self.image_label.current_tool == "paint_brush": self.paint_brush_size = max(1, self.paint_brush_size + delta // 120) print(f"Paint brush size: {self.paint_brush_size}") elif self.image_label.current_tool == "eraser": self.eraser_size = max(1, self.eraser_size + delta // 120) print(f"Eraser size: {self.eraser_size}") else: super().wheelEvent(event) def update_ui_for_current_tool(self): # Disable finish_polygon_button if it still exists in your code if hasattr(self, 'finish_polygon_button'): self.finish_polygon_button.setEnabled(self.image_label.current_tool in ["polygon", "rectangle"]) # Update button states self.polygon_button.setChecked(self.image_label.current_tool == "polygon") self.rectangle_button.setChecked(self.image_label.current_tool == "rectangle") self.sam_magic_wand_button.setChecked(self.image_label.current_tool == "sam_magic_wand") # Enable/disable SAM button based on model availability self.sam_magic_wand_button.setEnabled(self.current_sam_model is not None) # Disable all tools if no class is selected tools_enabled = self.current_class is not None and not self.current_class.startswith("Temp-") for button in self.tool_group.buttons(): button.setEnabled(tools_enabled) # Update cursor based on the current tool if self.image_label.current_tool == "sam_magic_wand" and self.sam_magic_wand_button.isEnabled(): 
self.image_label.setCursor(Qt.CrossCursor) else: self.image_label.setCursor(Qt.ArrowCursor) def on_class_selected(self, current=None, previous=None): if not self.image_label.check_unsaved_changes(): return if current is None: current = self.class_list.currentItem() if current: self.current_class = current.text() print(f"Class selected: {self.current_class}") if self.current_class.startswith("Temp-"): self.disable_annotation_tools() else: self.enable_annotation_tools() else: self.current_class = None self.disable_annotation_tools() def disable_annotation_tools(self): for button in self.tool_group.buttons(): button.setChecked(False) button.setEnabled(False) self.image_label.current_tool = None def enable_annotation_tools(self): for button in self.tool_group.buttons(): button.setEnabled(True) def show_class_context_menu(self, position): menu = QMenu() rename_action = menu.addAction("Rename Class") change_color_action = menu.addAction("Change Color") delete_action = menu.addAction("Delete Class") item = self.class_list.itemAt(position) if item: action = menu.exec_(self.class_list.mapToGlobal(position)) if action == rename_action: self.rename_class(item) elif action == change_color_action: self.change_class_color(item) elif action == delete_action: self.delete_class(item) else: QMessageBox.warning(self, "No Selection", "Please select a class to perform actions.") def change_class_color(self, item): class_name = item.text() current_color = self.image_label.class_colors.get(class_name, QColor(Qt.white)) color = QColorDialog.getColor(current_color, self, f"Select Color for {class_name}") if color.isValid(): self.image_label.class_colors[class_name] = color # Update the color indicator pixmap = QPixmap(16, 16) pixmap.fill(color) item.setIcon(QIcon(pixmap)) self.update_annotation_list_colors(class_name, color) self.image_label.update() self.auto_save() # Auto-save after changing class color def rename_class(self, item): old_name = item.text() new_name, ok = 
QInputDialog.getText(self, "Rename Class", "Enter new class name:", text=old_name) if ok and new_name and new_name != old_name: # Update class mapping if old_name in self.class_mapping: old_id = self.class_mapping[old_name] self.class_mapping[new_name] = old_id del self.class_mapping[old_name] else: print(f"Warning: Class '{old_name}' not found in class_mapping") return # Update class colors if old_name in self.image_label.class_colors: self.image_label.class_colors[new_name] = self.image_label.class_colors.pop(old_name) else: print(f"Warning: Class '{old_name}' not found in class_colors") return # Update annotations for all images and slices for image_name, image_annotations in self.all_annotations.items(): if old_name in image_annotations: image_annotations[new_name] = image_annotations.pop(old_name) for annotation in image_annotations[new_name]: annotation['category_name'] = new_name # Update current image annotations if old_name in self.image_label.annotations: self.image_label.annotations[new_name] = self.image_label.annotations.pop(old_name) for annotation in self.image_label.annotations[new_name]: annotation['category_name'] = new_name # Update current class if it's the renamed one if self.current_class == old_name: self.current_class = new_name # Update annotation list for all images and slices self.update_all_annotation_lists() # Update class list item.setText(new_name) # Update the image label self.image_label.update() self.auto_save() # Auto-save after renaming a class print(f"Class renamed from '{old_name}' to '{new_name}'") def delete_class(self, item=None): if item is None: item = self.class_list.currentItem() if item is None: QMessageBox.warning(self, "No Selection", "Please select a class to delete.") return class_name = item.text() # Show confirmation dialog reply = QMessageBox.question(self, 'Delete Class', f"Are you sure you want to delete the class '{class_name}'?\n\n" "This will remove all annotations associated with this class.", 
QMessageBox.Yes | QMessageBox.No, QMessageBox.No) if reply == QMessageBox.Yes: # Proceed with deletion # Remove class color self.image_label.class_colors.pop(class_name, None) # Remove class from mapping self.class_mapping.pop(class_name, None) # Remove annotations for this class from all images for image_annotations in self.all_annotations.values(): image_annotations.pop(class_name, None) # Remove annotations for this class from current image self.image_label.annotations.pop(class_name, None) # Update annotation list self.update_annotation_list() # Remove class from list row = self.class_list.row(item) self.class_list.takeItem(row) # Update current_class if self.current_class == class_name: self.current_class = None if self.class_list.count() > 0: self.class_list.setCurrentRow(0) self.on_class_selected(self.class_list.item(0)) else: self.disable_annotation_tools() self.image_label.update() # Inform the user QMessageBox.information(self, "Class Deleted", f"The class '{class_name}' has been deleted.") self.auto_save() # Auto-save after deleting a class else: # User cancelled the operation QMessageBox.information(self, "Deletion Cancelled", "The class deletion was cancelled.") def finish_polygon(self): if self.image_label.current_tool == "polygon" and len(self.image_label.current_annotation) > 2: if self.current_class is None: QMessageBox.warning(self, "No Class Selected", "Please select a class before finishing the annotation.") return # Create a polygon from the current annotation polygon = Polygon(self.image_label.current_annotation) # Define the image boundary as a rectangle image_boundary = Polygon([(0, 0), (self.current_image.width(), 0), (self.current_image.width(), self.current_image.height()), (0, self.current_image.height())]) # Intersect the polygon with the image boundary clipped_polygon = polygon.intersection(image_boundary) if clipped_polygon.is_empty: QMessageBox.warning(self, "Invalid Annotation", "The annotation is completely outside the image 
boundaries.") self.image_label.clear_current_annotation() self.image_label.update() return # Convert the clipped polygon to a segmentation format if isinstance(clipped_polygon, Polygon): segmentation = [coord for point in clipped_polygon.exterior.coords for coord in point] elif isinstance(clipped_polygon, MultiPolygon): largest_polygon = max(clipped_polygon.geoms, key=lambda p: p.area) segmentation = [coord for point in largest_polygon.exterior.coords for coord in point] else: QMessageBox.warning(self, "Invalid Annotation", "The annotation could not be processed.") return new_annotation = { "segmentation": segmentation, "category_id": self.class_mapping[self.current_class], "category_name": self.current_class, } self.image_label.annotations.setdefault(self.current_class, []).append(new_annotation) self.add_annotation_to_list(new_annotation) self.image_label.clear_current_annotation() self.image_label.drawing_polygon = False # Reset the drawing_polygon flag self.image_label.reset_annotation_state() self.image_label.update() # Save the current annotations self.save_current_annotations() # Update the slice list colors self.update_slice_list_colors() self.auto_save() # Auto-save after adding a polygon annotation def highlight_annotation(self, item): self.image_label.highlighted_annotation = item.data(Qt.UserRole) self.image_label.update() def delete_annotation(self): current_item = self.annotation_list.currentItem() if current_item: annotation = current_item.data(Qt.UserRole) category_name = annotation['category_name'] self.image_label.annotations[category_name].remove(annotation) self.annotation_list.takeItem(self.annotation_list.row(current_item)) self.image_label.highlighted_annotation = None self.image_label.update() def add_annotation_to_list(self, annotation): class_name = annotation['category_name'] color = self.image_label.class_colors.get(class_name, QColor(Qt.white)) annotations = self.image_label.annotations.get(class_name, []) number = 
max([ann.get('number', 0) for ann in annotations] + [0]) + 1 annotation['number'] = number area = calculate_area(annotation) item_text = f"{class_name} - {number:<3} Area: {area:.2f}" item = QListWidgetItem(item_text) item.setData(Qt.UserRole, annotation) item.setForeground(color) self.annotation_list.addItem(item) # Clear the current selection self.annotation_list.clearSelection() self.image_label.highlighted_annotations.clear() self.image_label.update() def zoom_in(self): new_zoom = min(self.image_label.zoom_factor + 0.1, 5.0) self.set_zoom(new_zoom) def zoom_out(self): new_zoom = max(self.image_label.zoom_factor - 0.1, 0.1) self.set_zoom(new_zoom) def set_zoom(self, zoom_factor): self.image_label.set_zoom(zoom_factor) self.zoom_slider.setValue(int(zoom_factor * 100)) self.image_label.update() def zoom_image(self): zoom_factor = self.zoom_slider.value() / 100 self.set_zoom(zoom_factor) def disable_tools(self): self.polygon_button.setEnabled(False) self.rectangle_button.setEnabled(False) #self.finish_polygon_button.setEnabled(False) def enable_tools(self): self.polygon_button.setEnabled(True) self.rectangle_button.setEnabled(True) def finish_rectangle(self): if self.image_label.current_rectangle: x1, y1, x2, y2 = self.image_label.current_rectangle # Create a rectangle polygon from the annotation rectangle = Polygon([(x1, y1), (x2, y1), (x2, y2), (x1, y2)]) # Define the image boundary as a rectangle image_boundary = Polygon([(0, 0), (self.current_image.width(), 0), (self.current_image.width(), self.current_image.height()), (0, self.current_image.height())]) # Intersect the rectangle with the image boundary clipped_rectangle = rectangle.intersection(image_boundary) if clipped_rectangle.is_empty: QMessageBox.warning(self, "Invalid Annotation", "The annotation is completely outside the image boundaries.") self.image_label.current_rectangle = None self.image_label.update() return # Convert the clipped rectangle to a segmentation format if isinstance(clipped_rectangle, 
Polygon): segmentation = [coord for point in clipped_rectangle.exterior.coords for coord in point] elif isinstance(clipped_rectangle, MultiPolygon): largest_polygon = max(clipped_rectangle.geoms, key=lambda p: p.area) segmentation = [coord for point in largest_polygon.exterior.coords for coord in point] else: QMessageBox.warning(self, "Invalid Annotation", "The annotation could not be processed.") return new_annotation = { "segmentation": segmentation, "category_id": self.class_mapping[self.current_class], "category_name": self.current_class, } self.image_label.annotations.setdefault(self.current_class, []).append(new_annotation) self.add_annotation_to_list(new_annotation) self.image_label.start_point = None self.image_label.end_point = None self.image_label.current_rectangle = None self.image_label.update() # Save the current annotations self.save_current_annotations() # Update the slice list colors self.update_slice_list_colors() self.auto_save() def enter_edit_mode(self, annotation): self.editing_mode = True self.disable_tools() QMessageBox.information(self, "Edit Mode", "You are now in edit mode. 
Click and drag points to move them, Shift+Click to delete points, or click on edges to add new points.") def exit_edit_mode(self): self.editing_mode = False self.enable_tools() self.image_label.editing_polygon = None self.image_label.editing_point_index = None self.image_label.hover_point_index = None self.update_annotation_list() self.image_label.update() def highlight_annotation_in_list(self, annotation): for i in range(self.annotation_list.count()): item = self.annotation_list.item(i) if item.data(Qt.UserRole) == annotation: self.annotation_list.setCurrentItem(item) break def select_annotation_in_list(self, annotation): for i in range(self.annotation_list.count()): item = self.annotation_list.item(i) if item.data(Qt.UserRole) == annotation: self.annotation_list.setCurrentItem(item) break ################################################################ def setup_yolo_menu(self): yolo_menu = self.menuBar().addMenu("&YOLO (beta)") # Training submenu training_submenu = yolo_menu.addMenu("Training") load_pretrained_action = QAction("Load Pre-trained Model", self) load_pretrained_action.triggered.connect(self.load_yolo_model) training_submenu.addAction(load_pretrained_action) prepare_data_action = QAction("Prepare YOLO Dataset", self) prepare_data_action.triggered.connect(self.prepare_yolo_dataset) training_submenu.addAction(prepare_data_action) load_yaml_action = QAction("Load Dataset YAML", self) load_yaml_action.triggered.connect(self.load_yolo_yaml) training_submenu.addAction(load_yaml_action) train_action = QAction("Train Model", self) train_action.triggered.connect(self.show_train_dialog) training_submenu.addAction(train_action) save_model_action = QAction("Save Model", self) save_model_action.triggered.connect(self.save_yolo_model) training_submenu.addAction(save_model_action) # Prediction Settings submenu prediction_submenu = yolo_menu.addMenu("Prediction Settings") load_model_action = QAction("Load Model", self) 
load_model_action.triggered.connect(self.load_prediction_model) prediction_submenu.addAction(load_model_action) set_threshold_action = QAction("Set Confidence Threshold", self) set_threshold_action.triggered.connect(self.set_confidence_threshold) prediction_submenu.addAction(set_threshold_action) def load_yolo_model(self): if not hasattr(self, 'current_project_dir'): QMessageBox.warning(self, "No Project", "Please open or create a project first.") return if not self.yolo_trainer: self.initialize_yolo_trainer() if self.yolo_trainer.load_model(): QMessageBox.information(self, "Model Loaded", "YOLO model loaded successfully.") else: QMessageBox.warning(self, "Load Cancelled", "Model loading was cancelled.") def prepare_yolo_dataset(self): if not hasattr(self, 'current_project_file'): QMessageBox.warning(self, "No Project", "Please open or create a project first.") return if not self.yolo_trainer: self.initialize_yolo_trainer() try: yaml_path = self.yolo_trainer.prepare_dataset() QMessageBox.information(self, "Dataset Prepared", f"YOLO dataset prepared successfully. 
YAML file: {yaml_path}") except Exception as e: QMessageBox.critical(self, "Error", f"An error occurred while preparing the dataset: {str(e)}") def load_yolo_yaml(self): if not hasattr(self, 'current_project_file'): QMessageBox.warning(self, "No Project", "Please open or create a project first.") return if not self.yolo_trainer: self.initialize_yolo_trainer() try: if self.yolo_trainer.load_yaml(): QMessageBox.information(self, "YAML Loaded", "Dataset YAML loaded successfully.") else: QMessageBox.warning(self, "Load Cancelled", "YAML loading was cancelled.") except Exception as e: QMessageBox.critical(self, "Error", f"An error occurred while loading the YAML file: {str(e)}") def save_yolo_model(self): if not hasattr(self, 'current_project_file'): QMessageBox.warning(self, "No Project", "Please open or create a project first.") return if not self.yolo_trainer or not self.yolo_trainer.model: QMessageBox.warning(self, "No Model", "Please train or load a YOLO model first.") return try: if self.yolo_trainer.save_model(): QMessageBox.information(self, "Model Saved", "YOLO model saved successfully.") else: QMessageBox.warning(self, "Save Cancelled", "Model saving was cancelled.") except Exception as e: QMessageBox.critical(self, "Error", f"An error occurred while saving the model: {str(e)}") def load_prediction_model(self): if not hasattr(self, 'current_project_file'): QMessageBox.warning(self, "No Project", "Please open or create a project first.") return if not self.yolo_trainer: self.initialize_yolo_trainer() dialog = LoadPredictionModelDialog(self) if dialog.exec_() == QDialog.Accepted: model_path = dialog.model_path yaml_path = dialog.yaml_path if model_path and yaml_path: try: result, message = self.yolo_trainer.load_prediction_model(model_path, yaml_path) if result: QMessageBox.information(self, "Model Loaded", "YOLO model and YAML file loaded successfully for prediction.") if message: QMessageBox.warning(self, "Class Mismatch Warning", message) else: 
QMessageBox.critical(self, "Error Loading Model", f"Could not load the model or YAML file: {message}") except Exception as e: QMessageBox.critical(self, "Error", f"An error occurred: {str(e)}") else: QMessageBox.warning(self, "Files Required", "Both model and YAML files are required for prediction.") def show_train_dialog(self): if not self.yolo_trainer: QMessageBox.warning(self, "No Project", "Please open or create a project first.") return if not self.yolo_trainer.model: QMessageBox.warning(self, "No Model", "Please load a pre-trained model first.") return if not self.yolo_trainer.yaml_path: QMessageBox.warning(self, "No Dataset", "Please prepare or load a dataset YAML first.") return dialog = QDialog(self) dialog.setWindowTitle("Train YOLO Model") layout = QVBoxLayout() epochs_label = QLabel("Number of Epochs:") epochs_input = QLineEdit("100") layout.addWidget(epochs_label) layout.addWidget(epochs_input) imgsz_label = QLabel("Image Size:") imgsz_input = QLineEdit("640") layout.addWidget(imgsz_label) layout.addWidget(imgsz_input) button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) button_box.accepted.connect(dialog.accept) button_box.rejected.connect(dialog.reject) layout.addWidget(button_box) dialog.setLayout(layout) if dialog.exec_() == QDialog.Accepted: epochs = int(epochs_input.text()) imgsz = int(imgsz_input.text()) self.start_training(epochs, imgsz) def initialize_yolo_trainer(self): if hasattr(self, 'current_project_dir'): self.yolo_trainer = YOLOTrainer(self.current_project_dir, self) else: QMessageBox.warning(self, "No Project", "Please open or create a project first.") def start_training(self, epochs, imgsz): if not hasattr(self, 'training_dialog'): self.training_dialog = TrainingInfoDialog(self) self.training_dialog.show() self.yolo_trainer.progress_signal.connect(self.training_dialog.update_info) self.yolo_trainer.set_progress_callback(self.training_dialog.update_info) 
self.training_dialog.stop_signal.connect(self.yolo_trainer.stop_training_signal) self.training_thread = TrainingThread(self.yolo_trainer, epochs, imgsz) self.training_thread.finished.connect(self.training_finished) self.training_thread.start() def training_finished(self, results): self.training_dialog.stop_button.setEnabled(True) self.training_dialog.stop_button.setText("Stop Training") self.yolo_trainer.progress_signal.disconnect(self.training_dialog.update_info) self.training_dialog.stop_signal.disconnect(self.yolo_trainer.stop_training_signal) if isinstance(results, str): QMessageBox.critical(self, "Training Error", f"An error occurred during training: {results}") else: QMessageBox.information(self, "Training Complete", "YOLO model training completed successfully.") def set_confidence_threshold(self): if not hasattr(self, 'current_project_file'): QMessageBox.warning(self, "No Project", "Please open or create a project first.") return if not self.yolo_trainer: self.initialize_yolo_trainer() current_threshold = self.yolo_trainer.conf_threshold new_threshold, ok = QInputDialog.getDouble(self, "Set Confidence Threshold", "Enter confidence threshold (0-1):", current_threshold, 0, 1, 2) if ok: self.yolo_trainer.set_conf_threshold(new_threshold) QMessageBox.information(self, "Threshold Updated", f"Confidence threshold set to {new_threshold}") def show_predict_dialog(self): if not self.yolo_trainer or not self.yolo_trainer.model: QMessageBox.warning(self, "No Model", "Please load a YOLO model first.") return dialog = QDialog(self) dialog.setWindowTitle("Predict with YOLO Model") layout = QVBoxLayout() image_list = QListWidget() for image_name in self.image_paths.keys(): image_list.addItem(image_name) layout.addWidget(QLabel("Select images for prediction:")) layout.addWidget(image_list) conf_label = QLabel("Confidence Threshold:") conf_input = QDoubleSpinBox() conf_input.setRange(0, 1) conf_input.setSingleStep(0.01) conf_input.setValue(self.yolo_trainer.conf_threshold) 
layout.addWidget(conf_label) layout.addWidget(conf_input) button_box = QDialogButtonBox(QDialogButtonBox.Cancel) predict_button = QPushButton("Predict") button_box.addButton(predict_button, QDialogButtonBox.AcceptRole) button_box.accepted.connect(dialog.accept) button_box.rejected.connect(dialog.reject) layout.addWidget(button_box) dialog.setLayout(layout) if dialog.exec_() == QDialog.Accepted: selected_images = [item.text() for item in image_list.selectedItems()] conf = conf_input.value() self.yolo_trainer.set_conf_threshold(conf) self.run_predictions(selected_images) def run_predictions(self, selected_images): for image_name in selected_images: image_path = self.image_paths[image_name] results = self.yolo_trainer.predict(image_path) self.process_yolo_results(results, image_name) def process_yolo_results(self, results, image_name): image_path = self.image_paths[image_name] image = cv2.imread(image_path) if image is None: QMessageBox.warning(self, "Error", f"Failed to load image: {image_name}") return original_height, original_width = image.shape[:2] temp_annotations = {} try: results, input_size, original_size = results # Unpack the results, input size, and original size input_height, input_width = input_size orig_height, orig_width = original_size scale_x = original_width / orig_width scale_y = original_height / orig_height for result in results: boxes = result.boxes masks = result.masks if masks is None: print(f"No masks found for {image_name}") continue for mask, box in zip(masks, boxes): try: class_id = int(box.cls) class_name = self.yolo_trainer.class_names[class_id] score = float(box.conf) mask_array = mask.data.cpu().numpy()[0] # Resize mask to original image size mask_array = cv2.resize(mask_array, (orig_width, orig_height)) contours, _ = cv2.findContours((mask_array > 0.5).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) if contours: epsilon = 0.005 * cv2.arcLength(contours[0], True) approx = cv2.approxPolyDP(contours[0], epsilon, True) 
polygon = approx.flatten().tolist() # Scale the polygon coordinates scaled_polygon = [] for i in range(0, len(polygon), 2): x = polygon[i] * scale_x y = polygon[i+1] * scale_y scaled_polygon.extend([x, y]) temp_class_name = f"Temp-{class_name}" if temp_class_name not in temp_annotations: temp_annotations[temp_class_name] = [] temp_annotation = { "segmentation": scaled_polygon, "category_name": temp_class_name, "score": score, "temp": True } temp_annotations[temp_class_name].append(temp_annotation) except IndexError: QMessageBox.warning(self, "Class Mismatch", "There is a mismatch between the model and the YAML file classes. " "Please check that the YAML file corresponds to the loaded model.") return except Exception as e: QMessageBox.warning(self, "Prediction Error", f"An error occurred during prediction: {str(e)}\n\n" "This might be due to a mismatch between the model and the YAML file classes. " "Please check that the YAML file corresponds to the loaded model.") return self.add_temp_classes(temp_annotations) self.update_class_list() self.image_label.update() if temp_annotations: total_predictions = sum(len(anns) for anns in temp_annotations.values()) QMessageBox.information(self, "Review Predictions", f"Found {total_predictions} predictions for {len(temp_annotations)} classes.\n" "Use class visibility checkboxes to review.\n" "Press Enter to accept or Esc to reject visible predictions.") else: QMessageBox.information(self, "No Predictions", "No predictions were found for this image.") # Deactivate SAM tool self.deactivate_sam_magic_wand() def add_temp_classes(self, temp_annotations): for temp_class_name, annotations in temp_annotations.items(): if temp_class_name not in self.image_label.class_colors: color = QColor(Qt.GlobalColor(len(self.image_label.class_colors) % 16 + 7)) self.image_label.class_colors[temp_class_name] = color self.image_label.annotations[temp_class_name] = annotations self.update_class_list() def verify_current_class(self): if 
self.current_class is None or self.current_class not in self.class_mapping: if self.class_list.count() > 0: self.class_list.setCurrentRow(0) self.on_class_selected(self.class_list.item(0)) else: self.current_class = None self.disable_annotation_tools() def accept_visible_temp_classes(self): visible_temp_classes = [item.text() for item in self.class_list.findItems("Temp-*", Qt.MatchWildcard) if item.checkState() == Qt.Checked] for temp_class_name in visible_temp_classes: permanent_class_name = temp_class_name[5:] # Remove "Temp-" prefix if permanent_class_name not in self.image_label.annotations: self.add_class(permanent_class_name, self.image_label.class_colors[temp_class_name]) # Get the current maximum number for this class current_max = max([ann.get('number', 0) for ann in self.image_label.annotations.get(permanent_class_name, [])] + [0]) for annotation in self.image_label.annotations[temp_class_name]: current_max += 1 annotation['category_name'] = permanent_class_name annotation['number'] = current_max self.image_label.annotations.setdefault(permanent_class_name, []).append(annotation) del self.image_label.annotations[temp_class_name] del self.image_label.class_colors[temp_class_name] self.update_class_list() current_name = self.current_slice or self.image_file_name self.all_annotations[current_name] = self.image_label.annotations self.update_annotation_list() self.image_label.update() self.save_current_annotations() # Select the first primary class self.select_first_primary_class() self.verify_current_class() QMessageBox.information(self, "Annotations Accepted", "Temporary annotations have been accepted and added to the permanent classes.") def select_first_primary_class(self): for i in range(self.class_list.count()): item = self.class_list.item(i) if not item.text().startswith("Temp-"): self.class_list.setCurrentItem(item) self.on_class_selected(item) break def reject_visible_temp_classes(self): visible_temp_classes = [item.text() for item in 
self.class_list.findItems("Temp-*", Qt.MatchWildcard) if item.checkState() == Qt.Checked] for temp_class_name in visible_temp_classes: if temp_class_name in self.image_label.annotations: del self.image_label.annotations[temp_class_name] if temp_class_name in self.image_label.class_colors: del self.image_label.class_colors[temp_class_name] self.update_class_list() self.image_label.update() def is_class_visible(self, class_name): items = self.class_list.findItems(class_name, Qt.MatchExactly) if items: return items[0].checkState() == Qt.Checked return False def check_temp_annotations(self): temp_classes = [class_name for class_name in self.image_label.annotations.keys() if class_name.startswith("Temp-")] if temp_classes: reply = QMessageBox.question(self, 'Temporary Annotations', "There are temporary annotations that will be discarded. Do you want to continue?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No) if reply == QMessageBox.Yes: for temp_class in temp_classes: del self.image_label.annotations[temp_class] del self.image_label.class_colors[temp_class] self.update_class_list() self.update_annotation_list() return True return False return True def remove_all_temp_annotations(self): for image_name in list(self.all_annotations.keys()): for class_name in list(self.all_annotations[image_name].keys()): if class_name.startswith("Temp-"): del self.all_annotations[image_name][class_name] if not self.all_annotations[image_name]: del self.all_annotations[image_name] for class_name in list(self.image_label.class_colors.keys()): if class_name.startswith("Temp-"): del self.image_label.class_colors[class_name] self.update_class_list() self.update_annotation_list() self.image_label.update() ================================================ FILE: src/digitalsreeni_image_annotator/coco_json_combiner.py ================================================ import json import os from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, 
QMessageBox, QApplication) from PyQt5.QtCore import Qt class COCOJSONCombinerDialog(QDialog): def __init__(self, parent=None): super().__init__(parent) self.setWindowTitle("COCO JSON Combiner") self.setGeometry(100, 100, 400, 300) self.setWindowFlags(self.windowFlags() | Qt.Window) self.setWindowModality(Qt.ApplicationModal) self.json_files = [] self.initUI() def initUI(self): layout = QVBoxLayout() self.file_labels = [] for i in range(5): file_layout = QHBoxLayout() label = QLabel(f"File {i+1}: Not selected") self.file_labels.append(label) file_layout.addWidget(label) select_button = QPushButton(f"Select File {i+1}") select_button.clicked.connect(lambda checked, x=i: self.select_file(x)) file_layout.addWidget(select_button) layout.addLayout(file_layout) self.combine_button = QPushButton("Combine JSON Files") self.combine_button.clicked.connect(self.combine_json_files) self.combine_button.setEnabled(False) layout.addWidget(self.combine_button) self.setLayout(layout) def select_file(self, index): file_name, _ = QFileDialog.getOpenFileName(self, f"Select COCO JSON File {index+1}", "", "JSON Files (*.json)") if file_name: if file_name not in self.json_files: self.json_files.append(file_name) self.file_labels[index].setText(f"File {index+1}: {os.path.basename(file_name)}") self.combine_button.setEnabled(True) else: QMessageBox.warning(self, "Duplicate File", "This file has already been selected.") QApplication.processEvents() def combine_json_files(self): if not self.json_files: QMessageBox.warning(self, "No Files", "Please select at least one JSON file to combine.") return combined_data = { "images": [], "annotations": [], "categories": [] } image_file_names = set() next_image_id = 1 next_annotation_id = 1 try: for file_path in self.json_files: with open(file_path, 'r') as f: data = json.load(f) # Combine categories category_id_map = {} for category in data.get('categories', []): existing_category = next((c for c in combined_data['categories'] if c['name'] == 
category['name']), None) if existing_category: category_id_map[category['id']] = existing_category['id'] else: new_id = len(combined_data['categories']) + 1 category_id_map[category['id']] = new_id category['id'] = new_id combined_data['categories'].append(category) # Combine images and annotations image_id_map = {} for image in data.get('images', []): if image['file_name'] not in image_file_names: image_file_names.add(image['file_name']) image_id_map[image['id']] = next_image_id image['id'] = next_image_id combined_data['images'].append(image) next_image_id += 1 for annotation in data.get('annotations', []): if annotation['image_id'] in image_id_map: annotation['id'] = next_annotation_id annotation['image_id'] = image_id_map[annotation['image_id']] annotation['category_id'] = category_id_map[annotation['category_id']] combined_data['annotations'].append(annotation) next_annotation_id += 1 output_file, _ = QFileDialog.getSaveFileName(self, "Save Combined JSON", "", "JSON Files (*.json)") if output_file: with open(output_file, 'w') as f: json.dump(combined_data, f, indent=2) QMessageBox.information(self, "Success", f"Combined JSON saved to {output_file}") except Exception as e: QMessageBox.critical(self, "Error", f"An error occurred while combining JSON files: {str(e)}") def show_centered(self, parent): parent_geo = parent.geometry() self.move(parent_geo.center() - self.rect().center()) self.show() def show_coco_json_combiner(parent): dialog = COCOJSONCombinerDialog(parent) dialog.show_centered(parent) return dialog ================================================ FILE: src/digitalsreeni_image_annotator/constants.py ================================================ """ Constants for the Image Annotator application. This module contains constant values used across the application. @DigitalSreeni Dr. 
Sreenivas Bhattiprolu """ # File dialog filters IMAGE_FILE_FILTER = "Image Files (*.png *.jpg *.bmp)" JSON_FILE_FILTER = "JSON Files (*.json)" # Default window size DEFAULT_WINDOW_WIDTH = 1400 DEFAULT_WINDOW_HEIGHT = 800 # Zoom settings MIN_ZOOM = 10 MAX_ZOOM = 500 DEFAULT_ZOOM = 100 # Annotation settings DEFAULT_FILL_OPACITY = 0.3 ================================================ FILE: src/digitalsreeni_image_annotator/dataset_splitter.py ================================================ import os import json import shutil import random from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QSpinBox, QRadioButton, QButtonGroup, QMessageBox, QComboBox) from PyQt5.QtCore import Qt import yaml from PIL import Image class DatasetSplitterTool(QDialog): def __init__(self, parent=None): super().__init__(parent) self.setWindowTitle("Dataset Splitter") self.setGeometry(100, 100, 500, 300) self.setWindowFlags(self.windowFlags() | Qt.Window) self.initUI() def initUI(self): layout = QVBoxLayout() # Option selection options_layout = QVBoxLayout() self.images_only_radio = QRadioButton("Images Only") options_layout.addWidget(self.images_only_radio) images_annotations_layout = QHBoxLayout() self.images_annotations_radio = QRadioButton("Images and Annotations") images_annotations_layout.addWidget(self.images_annotations_radio) self.select_json_button = QPushButton("Upload COCO JSON File") self.select_json_button.clicked.connect(self.select_json_file) self.select_json_button.setEnabled(False) images_annotations_layout.addWidget(self.select_json_button) options_layout.addLayout(images_annotations_layout) layout.addLayout(options_layout) option_group = QButtonGroup(self) option_group.addButton(self.images_only_radio) option_group.addButton(self.images_annotations_radio) self.images_only_radio.setChecked(True) # Percentage inputs train_layout = QHBoxLayout() train_layout.addWidget(QLabel("Train %:")) self.train_percent = QSpinBox() 
self.train_percent.setRange(0, 100) self.train_percent.setValue(70) train_layout.addWidget(self.train_percent) layout.addLayout(train_layout) val_layout = QHBoxLayout() val_layout.addWidget(QLabel("Validation %:")) self.val_percent = QSpinBox() self.val_percent.setRange(0, 100) self.val_percent.setValue(30) val_layout.addWidget(self.val_percent) layout.addLayout(val_layout) test_layout = QHBoxLayout() test_layout.addWidget(QLabel("Test %:")) self.test_percent = QSpinBox() self.test_percent.setRange(0, 100) self.test_percent.setValue(0) test_layout.addWidget(self.test_percent) layout.addLayout(test_layout) # Format selection self.format_selection_layout = QHBoxLayout() self.format_label = QLabel("Output Format:") self.format_combo = QComboBox() self.format_combo.addItems(["COCO JSON", "YOLO"]) self.format_combo.setEnabled(False) self.format_selection_layout.addWidget(self.format_label) self.format_selection_layout.addWidget(self.format_combo) options_layout.addLayout(self.format_selection_layout) # Buttons self.select_input_button = QPushButton("Select Input Directory") self.select_input_button.clicked.connect(self.select_input_directory) layout.addWidget(self.select_input_button) self.select_output_button = QPushButton("Select Output Directory") self.select_output_button.clicked.connect(self.select_output_directory) layout.addWidget(self.select_output_button) self.split_button = QPushButton("Split Dataset") self.split_button.clicked.connect(self.split_dataset) layout.addWidget(self.split_button) self.setLayout(layout) self.input_directory = "" self.output_directory = "" self.json_file = "" # Connect radio buttons to enable/disable JSON selection self.images_only_radio.toggled.connect(self.toggle_json_selection) self.images_annotations_radio.toggled.connect(self.toggle_json_selection) def toggle_json_selection(self): is_annotations = self.images_annotations_radio.isChecked() self.select_json_button.setEnabled(is_annotations) 
self.format_combo.setEnabled(is_annotations) def select_input_directory(self): self.input_directory = QFileDialog.getExistingDirectory(self, "Select Input Directory") def select_output_directory(self): self.output_directory = QFileDialog.getExistingDirectory(self, "Select Output Directory") def select_json_file(self): self.json_file, _ = QFileDialog.getOpenFileName(self, "Select COCO JSON File", "", "JSON Files (*.json)") def split_dataset(self): if not self.input_directory or not self.output_directory: QMessageBox.warning(self, "Error", "Please select input and output directories.") return if self.images_annotations_radio.isChecked() and not self.json_file: QMessageBox.warning(self, "Error", "Please select a COCO JSON file.") return train_percent = self.train_percent.value() val_percent = self.val_percent.value() test_percent = self.test_percent.value() if train_percent + val_percent + test_percent != 100: QMessageBox.warning(self, "Error", "Percentages must add up to 100%.") return if self.images_only_radio.isChecked(): self.split_images_only() else: self.split_images_and_annotations() def split_images_only(self): image_files = [f for f in os.listdir(self.input_directory) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff'))] random.shuffle(image_files) train_split = int(len(image_files) * self.train_percent.value() / 100) val_split = int(len(image_files) * self.val_percent.value() / 100) train_images = image_files[:train_split] val_images = image_files[train_split:train_split + val_split] test_images = image_files[train_split + val_split:] for subset, images in [("train", train_images), ("val", val_images), ("test", test_images)]: if images: # Only create directories and copy images if there are images for this split subset_dir = os.path.join(self.output_directory, subset) os.makedirs(subset_dir, exist_ok=True) self.copy_images(images, subset, images_only=True) QMessageBox.information(self, "Success", "Dataset split successfully!") def 
split_images_and_annotations(self): with open(self.json_file, 'r') as f: coco_data = json.load(f) image_files = [img['file_name'] for img in coco_data['images']] random.shuffle(image_files) train_split = int(len(image_files) * self.train_percent.value() / 100) val_split = int(len(image_files) * self.val_percent.value() / 100) train_images = image_files[:train_split] val_images = image_files[train_split:train_split + val_split] test_images = image_files[train_split + val_split:] # Create main directories os.makedirs(self.output_directory, exist_ok=True) if self.format_combo.currentText() == "COCO JSON": self.split_coco_format(coco_data, train_images, val_images, test_images) else: # YOLO format self.split_yolo_format(coco_data, train_images, val_images, test_images) def copy_images(self, image_list, subset, images_only=False): if not image_list: return if images_only: subset_dir = os.path.join(self.output_directory, subset) else: subset_dir = os.path.join(self.output_directory, subset, "images") os.makedirs(subset_dir, exist_ok=True) for image in image_list: src = os.path.join(self.input_directory, image) dst = os.path.join(subset_dir, image) shutil.copy2(src, dst) def create_subset_annotations(self, coco_data, subset_images): subset_images_data = [img for img in coco_data['images'] if img['file_name'] in subset_images] subset_image_ids = [img['id'] for img in subset_images_data] return { "images": subset_images_data, "annotations": [ann for ann in coco_data['annotations'] if ann['image_id'] in subset_image_ids], "categories": coco_data['categories'] } def split_coco_format(self, coco_data, train_images, val_images, test_images): # Only create directories and save annotations for non-empty splits for subset, images in [("train", train_images), ("val", val_images), ("test", test_images)]: if images: # Only process if there are images in this split subset_dir = os.path.join(self.output_directory, subset) os.makedirs(subset_dir, exist_ok=True) # Create the subset 
directory first os.makedirs(os.path.join(subset_dir, "images"), exist_ok=True) self.copy_images(images, subset, images_only=False) # Create and save annotations for this subset subset_data = self.create_subset_annotations(coco_data, images) self.save_coco_annotations(subset_data, subset) QMessageBox.information(self, "Success", "Dataset and COCO annotations split successfully!") def save_coco_annotations(self, data, subset): subset_dir = os.path.join(self.output_directory, subset) os.makedirs(subset_dir, exist_ok=True) output_file = os.path.join(subset_dir, f"{subset}_annotations.json") with open(output_file, 'w') as f: json.dump(data, f, indent=2) def split_yolo_format(self, coco_data, train_images, val_images, test_images): # Create directories only for non-empty splits yaml_paths = {} for subset, images in [("train", train_images), ("val", val_images), ("test", test_images)]: if images: # Only create directories if there are images for this split subset_dir = os.path.join(self.output_directory, subset) os.makedirs(os.path.join(subset_dir, "images"), exist_ok=True) os.makedirs(os.path.join(subset_dir, "labels"), exist_ok=True) yaml_paths[subset] = f'./{subset}/images' # Create class mapping (COCO to YOLO indices) categories = {cat["id"]: i for i, cat in enumerate(coco_data["categories"])} # Process each non-empty subset for subset, images in [("train", train_images), ("val", val_images), ("test", test_images)]: if not images: # Skip if no images in this split continue images_dir = os.path.join(self.output_directory, subset, "images") labels_dir = os.path.join(self.output_directory, subset, "labels") for image_file in images: # Copy image src = os.path.join(self.input_directory, image_file) shutil.copy2(src, os.path.join(images_dir, image_file)) # Get image dimensions img = Image.open(src) img_width, img_height = img.size # Get annotations for this image image_id = next(img["id"] for img in coco_data["images"] if img["file_name"] == image_file) annotations = [ann 
for ann in coco_data["annotations"] if ann["image_id"] == image_id] # Create YOLO format labels label_file = os.path.join(labels_dir, os.path.splitext(image_file)[0] + ".txt") with open(label_file, "w") as f: for ann in annotations: # Convert COCO class id to YOLO class id yolo_class = categories[ann["category_id"]] # Convert COCO bbox to YOLO format x, y, w, h = ann["bbox"] x_center = (x + w/2) / img_width y_center = (y + h/2) / img_height w = w / img_width h = h / img_height f.write(f"{yolo_class} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}\n") # Create data.yaml with only the relevant paths yaml_data = { 'nc': len(categories), 'names': [cat["name"] for cat in sorted(coco_data["categories"], key=lambda x: categories[x["id"]])] } yaml_data.update(yaml_paths) # Add only paths for non-empty splits with open(os.path.join(self.output_directory, 'data.yaml'), 'w') as f: yaml.dump(yaml_data, f, default_flow_style=False) QMessageBox.information(self, "Success", "Dataset and YOLO annotations split successfully!") def show_centered(self, parent): parent_geo = parent.geometry() self.move(parent_geo.center() - self.rect().center()) self.show() ================================================ FILE: src/digitalsreeni_image_annotator/default_stylesheet.py ================================================ default_stylesheet = """ QWidget { background-color: #F0F0F0; color: #333333; font-family: Arial, sans-serif; } QMainWindow { background-color: #FFFFFF; } QPushButton { background-color: #E0E0E0; border: 1px solid #BBBBBB; padding: 5px 10px; border-radius: 3px; color: #333333; } QPushButton:hover { background-color: #D0D0D0; } QPushButton:pressed { background-color: #C0C0C0; } QPushButton:checked { background-color: #A0A0A0; border: 2px solid #808080; color: #FFFFFF; } QListWidget, QTreeWidget { background-color: #FFFFFF; border: 1px solid #CCCCCC; border-radius: 3px; } QListWidget::item:selected { background-color: #E0E0E0; color: #333333; } QLabel { color: #333333; } 
QLabel.section-header { font-weight: bold; font-size: 14px; padding: 5px 0; color: #333333; /* Dark color for visibility in light mode */ } QLineEdit, QTextEdit, QPlainTextEdit { background-color: #FFFFFF; border: 1px solid #CCCCCC; color: #333333; padding: 2px; border-radius: 3px; } QSlider::groove:horizontal { background: #CCCCCC; height: 8px; border-radius: 4px; } QSlider::handle:horizontal { background: #888888; width: 18px; margin-top: -5px; margin-bottom: -5px; border-radius: 9px; } QSlider::handle:horizontal:hover { background: #666666; } QScrollBar:vertical, QScrollBar:horizontal { background-color: #F0F0F0; width: 12px; height: 12px; } QScrollBar::handle:vertical, QScrollBar::handle:horizontal { background-color: #CCCCCC; border-radius: 6px; min-height: 20px; } QScrollBar::handle:vertical:hover, QScrollBar::handle:horizontal:hover { background-color: #BBBBBB; } QScrollBar::add-line, QScrollBar::sub-line { background: none; } QMenuBar { background-color: #F0F0F0; } QMenuBar::item { padding: 5px 10px; background-color: transparent; } QMenuBar::item:selected { background-color: #E0E0E0; } QMenu { background-color: #FFFFFF; border: 1px solid #CCCCCC; } QMenu::item { padding: 5px 20px 5px 20px; } QMenu::item:selected { background-color: #E0E0E0; } QToolTip { background-color: #FFFFFF; color: #333333; border: 1px solid #CCCCCC; } QStatusBar { background-color: #F0F0F0; color: #666666; } QListWidget::item { color: none; } """ ================================================ FILE: src/digitalsreeni_image_annotator/dicom_converter.py ================================================ import os import json import numpy as np from datetime import datetime from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QProgressDialog, QRadioButton, QButtonGroup, QMessageBox, QApplication, QGroupBox) from PyQt5.QtCore import Qt import pydicom from pydicom.pixel_data_handlers.util import apply_voi_lut import tifffile class 
DicomConverter(QDialog): def __init__(self, parent=None): super().__init__(parent) self.setWindowTitle("DICOM to TIFF Converter") self.setGeometry(100, 100, 600, 300) self.setWindowFlags(self.windowFlags() | Qt.Window) self.setWindowModality(Qt.ApplicationModal) # Add modal behavior # Initialize variables first self.input_file = "" self.output_directory = "" self.initUI() def initUI(self): layout = QVBoxLayout() layout.setSpacing(10) # Add consistent spacing # File Selection Group file_group = QGroupBox("File Selection") file_layout = QVBoxLayout() # Input file selection input_layout = QHBoxLayout() self.input_label = QLabel("No DICOM file selected") self.input_label.setMinimumWidth(100) self.input_label.setMaximumWidth(300) self.input_label.setWordWrap(True) self.select_input_btn = QPushButton("Select DICOM File") self.select_input_btn.clicked.connect(self.select_input) input_layout.addWidget(self.select_input_btn) input_layout.addWidget(self.input_label, 1) file_layout.addLayout(input_layout) # Output directory selection output_layout = QHBoxLayout() self.output_label = QLabel("No output directory selected") self.output_label.setMinimumWidth(100) self.output_label.setMaximumWidth(300) self.output_label.setWordWrap(True) self.select_output_btn = QPushButton("Select Output Directory") self.select_output_btn.clicked.connect(self.select_output) output_layout.addWidget(self.select_output_btn) output_layout.addWidget(self.output_label, 1) file_layout.addLayout(output_layout) file_group.setLayout(file_layout) layout.addWidget(file_group) # Output Format Group format_group = QGroupBox("Output Format") format_layout = QVBoxLayout() self.stack_radio = QRadioButton("Single TIFF Stack") self.individual_radio = QRadioButton("Individual TIFF Files") self.stack_radio.setChecked(True) format_layout.addWidget(self.stack_radio) format_layout.addWidget(self.individual_radio) format_group.setLayout(format_layout) layout.addWidget(format_group) # Metadata info metadata_group = 
QGroupBox("Metadata Information") metadata_layout = QVBoxLayout() metadata_label = QLabel("DICOM metadata will be saved as JSON file in the output directory") metadata_label.setStyleSheet("color: gray; font-style: italic;") metadata_label.setWordWrap(True) metadata_layout.addWidget(metadata_label) metadata_group.setLayout(metadata_layout) layout.addWidget(metadata_group) # Convert button self.convert_btn = QPushButton("Convert") self.convert_btn.clicked.connect(self.convert_dicom) layout.addWidget(self.convert_btn) self.setLayout(layout) def select_input(self): try: file_filter = "DICOM files (*.dcm *.DCM);;All files (*.*)" file_name, _ = QFileDialog.getOpenFileName( self, "Select DICOM File", "", file_filter, options=QFileDialog.Options() ) if file_name: self.input_file = file_name self.input_label.setText(self.truncate_path(file_name)) self.input_label.setToolTip(file_name) QApplication.processEvents() except Exception as e: QMessageBox.critical(self, "Error", f"Error selecting input file: {str(e)}") def select_output(self): try: directory = QFileDialog.getExistingDirectory( self, "Select Output Directory", "", QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks ) if directory: self.output_directory = directory self.output_label.setText(self.truncate_path(directory)) self.output_label.setToolTip(directory) QApplication.processEvents() except Exception as e: QMessageBox.critical(self, "Error", f"Error selecting output directory: {str(e)}") def truncate_path(self, path, max_length=40): if len(path) <= max_length: return path filename = os.path.basename(path) directory = os.path.dirname(path) if len(filename) > max_length - 5: return f"...{filename[-(max_length-5):]}" available_length = max_length - len(filename) - 5 return f"...{directory[-available_length:]}{os.sep}{filename}" def extract_metadata(self, ds): """Extract relevant metadata from DICOM dataset.""" metadata = { "PatientID": getattr(ds, "PatientID", "Unknown"), "PatientName": str(getattr(ds, 
"PatientName", "Unknown")), "StudyDate": getattr(ds, "StudyDate", "Unknown"), "SeriesDescription": getattr(ds, "SeriesDescription", "Unknown"), "Modality": getattr(ds, "Modality", "Unknown"), "Manufacturer": getattr(ds, "Manufacturer", "Unknown"), "InstitutionName": getattr(ds, "InstitutionName", "Unknown"), "PixelSpacing": getattr(ds, "PixelSpacing", [1, 1]), "SliceThickness": getattr(ds, "SliceThickness", 1), "ImageOrientation": getattr(ds, "ImageOrientationPatient", [1,0,0,0,1,0]), "ImagePosition": getattr(ds, "ImagePositionPatient", [0,0,0]), "WindowCenter": getattr(ds, "WindowCenter", None), "WindowWidth": getattr(ds, "WindowWidth", None), "RescaleIntercept": getattr(ds, "RescaleIntercept", 0), "RescaleSlope": getattr(ds, "RescaleSlope", 1), "BitsAllocated": getattr(ds, "BitsAllocated", 16), "PixelRepresentation": getattr(ds, "PixelRepresentation", 0), "ConversionDate": datetime.now().strftime("%Y-%m-%d %H:%M:%S") } return metadata def apply_window_level(self, image, ds): """Apply window/level if present in DICOM.""" try: if hasattr(ds, 'WindowCenter') and hasattr(ds, 'WindowWidth'): return apply_voi_lut(image, ds) except: pass return image def convert_dicom(self): if not self.input_file or not self.output_directory: QMessageBox.warning(self, "Error", "Please select both input file and output directory") return try: # Create progress dialog progress = QProgressDialog("Processing DICOM file...", "Cancel", 0, 100, self) progress.setWindowModality(Qt.WindowModal) progress.setMinimumWidth(400) progress.show() # Verify DICOM file if not pydicom.misc.is_dicom(self.input_file): raise ValueError("Selected file is not a valid DICOM file") # Read DICOM data print("Reading DICOM file...") progress.setLabelText("Reading DICOM file...") progress.setValue(20) ds = pydicom.dcmread(self.input_file) series_metadata = self.extract_metadata(ds) # Process pixel data print("Processing pixel data...") progress.setLabelText("Processing pixel data...") progress.setValue(40) 
pixel_array = ds.pixel_array original_dtype = pixel_array.dtype print(f"Original data type: {original_dtype}") print(f"Original data range: {pixel_array.min()} to {pixel_array.max()}") # Apply rescale slope and intercept if hasattr(ds, 'RescaleSlope') or hasattr(ds, 'RescaleIntercept'): slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) print(f"Applying rescale slope ({slope}) and intercept ({intercept})") pixel_array = (pixel_array * slope + intercept) # Apply window/level print("Applying window/level adjustments...") pixel_array = self.apply_window_level(pixel_array, ds) print(f"Adjusted data range: {pixel_array.min()} to {pixel_array.max()}") print(f"Image shape: {pixel_array.shape}") print(f"Original dtype: {original_dtype}") # Save metadata progress.setLabelText("Saving metadata...") progress.setValue(60) metadata_file = os.path.join(self.output_directory, os.path.splitext(os.path.basename(self.input_file))[0] + "_metadata.json") with open(metadata_file, 'w') as f: json.dump(series_metadata, f, indent=2) # Get physical sizes from metadata pixel_spacing = series_metadata.get("PixelSpacing", [1, 1]) slice_thickness = series_metadata.get("SliceThickness", 1) print(f"Pixel spacing: {pixel_spacing}") print(f"Slice thickness: {slice_thickness}") # Save TIFF progress.setLabelText("Saving TIFF file(s)...") progress.setValue(80) # Convert back to original dtype if needed if np.issubdtype(original_dtype, np.integer): print("Converting back to original integer dtype...") data_min = pixel_array.min() data_max = pixel_array.max() if data_max != data_min: pixel_array = ((pixel_array - data_min) / (data_max - data_min) * np.iinfo(original_dtype).max).astype(original_dtype) else: pixel_array = np.zeros_like(pixel_array, dtype=original_dtype) print(f"Final data range: {pixel_array.min()} to {pixel_array.max()}") # Prepare ImageJ metadata imagej_metadata = { 'axes': 'YX', # Will be updated to ZYX for 3D data 'spacing': float(slice_thickness), 
# Only used for 3D data 'unit': 'um', 'finterval': float(pixel_spacing[0]) # XY pixel size } base_name = os.path.splitext(os.path.basename(self.input_file))[0] if self.stack_radio.isChecked(): # Save as single stack output_file = os.path.join(self.output_directory, f"{base_name}.tif") # Update axes for 3D data if len(pixel_array.shape) > 2: imagej_metadata['axes'] = 'ZYX' print(f"Saving stack with metadata: {imagej_metadata}") tifffile.imwrite( output_file, pixel_array, imagej=True, metadata=imagej_metadata, resolution=(1.0/float(pixel_spacing[0]), 1.0/float(pixel_spacing[1])) ) print(f"Saved stack to: {output_file}") print(f"Stack shape: {pixel_array.shape}") # Replace the individual slices saving section in convert_dicom method with this: else: # For multi-slice DICOM, save individual slices if len(pixel_array.shape) > 2: imagej_metadata['axes'] = 'YX' # Reset to 2D for individual slices total_slices = pixel_array.shape[0] for i in range(total_slices): progress.setLabelText(f"Saving slice {i+1}/{total_slices}...") # Fix: Convert float to integer for progress value progress_value = int(80 + (i/total_slices)*15) progress.setValue(progress_value) QApplication.processEvents() if progress.wasCanceled(): print("Operation cancelled by user") return output_file = os.path.join(self.output_directory, f"{base_name}_slice_{i+1:03d}.tif") print(f"Saving slice {i+1} with metadata: {imagej_metadata}") tifffile.imwrite( output_file, pixel_array[i], imagej=True, metadata=imagej_metadata, resolution=(1.0/float(pixel_spacing[0]), 1.0/float(pixel_spacing[1])) ) print(f"Saved {total_slices} individual slices") else: # Single slice DICOM output_file = os.path.join(self.output_directory, f"{base_name}.tif") print(f"Saving single slice with metadata: {imagej_metadata}") tifffile.imwrite( output_file, pixel_array, imagej=True, metadata=imagej_metadata, resolution=(1.0/float(pixel_spacing[0]), 1.0/float(pixel_spacing[1])) ) print(f"Saved single slice to: {output_file}") 
progress.setValue(100) # Construct success message msg = "Conversion complete!\n\n" msg += f"DICOM file: {os.path.basename(self.input_file)}\n" msg += f"Output directory: {self.truncate_path(self.output_directory)}\n\n" if self.stack_radio.isChecked(): msg += f"Saved as: {os.path.basename(output_file)}\n" else: if len(pixel_array.shape) > 2: msg += f"Saved {pixel_array.shape[0]} individual slices\n" else: msg += f"Saved as: {os.path.basename(output_file)}\n" msg += f"\nMetadata saved as: {os.path.basename(metadata_file)}\n" msg += f"Pixel spacing: {pixel_spacing[0]}x{pixel_spacing[1]} µm\n" if len(pixel_array.shape) > 2: msg += f"Slice thickness: {slice_thickness} µm" QMessageBox.information(self, "Success", msg) except Exception as e: QMessageBox.critical(self, "Error", str(e)) print(f"Error occurred: {str(e)}") import traceback traceback.print_exc() def show_centered(self, parent): parent_geo = parent.geometry() self.move(parent_geo.center() - self.rect().center()) self.show() QApplication.processEvents() # Ensure window displays properly def show_dicom_converter(parent): dialog = DicomConverter(parent) dialog.show_centered(parent) return dialog ================================================ FILE: src/digitalsreeni_image_annotator/export_formats.py ================================================ import json from PyQt5.QtGui import QImage from .utils import calculate_area, calculate_bbox import yaml import os import shutil import tempfile import xml.etree.ElementTree as ET from xml.dom import minidom from datetime import datetime import numpy as np import skimage.draw from PIL import Image # Utility function to handle the COCO conversion for all export formats def convert_to_coco(all_annotations, class_mapping, image_paths, slices, image_slices): with tempfile.TemporaryDirectory() as temp_dir: json_file_path, images_dir = export_coco_json(all_annotations, class_mapping, image_paths, slices, image_slices, temp_dir) with open(json_file_path, 'r') as f: coco_data 
= json.load(f) return coco_data, images_dir def export_coco_json(all_annotations, class_mapping, image_paths, slices, image_slices, output_dir, json_filename=None): coco_format = { "images": [], "categories": [{"id": id, "name": name} for name, id in class_mapping.items()], "annotations": [] } # Create images directory images_dir = os.path.join(output_dir, 'images') os.makedirs(images_dir, exist_ok=True) annotation_id = 1 image_id = 1 # Create a mapping of slice names to their QImage objects slice_map = {slice_name: qimage for slice_name, qimage in slices} # Handle all images and slices for image_name, annotations in all_annotations.items(): # Skip if there are no annotations for this image/slice if not annotations: continue # Check if it's a slice (either in slice_map or has underscores and no file extension) is_slice = image_name in slice_map or ('_' in image_name and '.' not in image_name) if is_slice: qimage = slice_map.get(image_name) if qimage is None: # If the slice is not in slice_map, it might be a CZI slice or a TIFF slice # Find the corresponding QImage in slices or image_slices matching_slices = [s for s in slices if s[0] == image_name] if matching_slices: qimage = matching_slices[0][1] else: # Check in image_slices for stack_slices in image_slices.values(): matching_slices = [s for s in stack_slices if s[0] == image_name] if matching_slices: qimage = matching_slices[0][1] break if qimage is None: print(f"No image data found for slice {image_name}, skipping") continue file_name_img = f"{image_name}.png" # Save the QImage as a file save_path = os.path.join(images_dir, file_name_img) if not os.path.exists(save_path): qimage.save(save_path) else: print(f"Image {file_name_img} already exists in the target directory. 
Skipping save.") else: # Check if the image_name exists in image_paths image_path = next((path for name, path in image_paths.items() if image_name in name), None) if not image_path: print(f"No image path found for {image_name}, skipping") continue if image_path.lower().endswith(('.tif', '.tiff', '.czi')): print(f"Skipping main tiff/czi file: {image_name}") continue file_name_img = image_name # Copy the image file dst_path = os.path.join(images_dir, file_name_img) if not os.path.exists(dst_path): shutil.copy2(image_path, dst_path) else: print(f"Image {file_name_img} already exists in the target directory. Skipping copy.") image_info = { "file_name": file_name_img, "height": qimage.height() if is_slice else QImage(image_path).height(), "width": qimage.width() if is_slice else QImage(image_path).width(), "id": image_id } coco_format["images"].append(image_info) for class_name, class_annotations in annotations.items(): for ann in class_annotations: coco_ann = create_coco_annotation(ann, image_id, annotation_id, class_name, class_mapping) coco_format["annotations"].append(coco_ann) annotation_id += 1 image_id += 1 # Generate JSON filename if not provided if json_filename is None: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") json_filename = f"annotations_{timestamp}.json" elif not json_filename.lower().endswith('.json'): json_filename += '.json' # Save COCO JSON file json_file_path = os.path.join(output_dir, json_filename) with open(json_file_path, 'w') as f: json.dump(coco_format, f, indent=2) return json_file_path, images_dir def create_coco_annotation(ann, image_id, annotation_id, class_name, class_mapping): coco_ann = { "id": annotation_id, "image_id": image_id, "category_id": class_mapping[class_name], "area": calculate_area(ann), "iscrowd": 0 } if "segmentation" in ann: coco_ann["segmentation"] = [ann["segmentation"]] coco_ann["bbox"] = calculate_bbox(ann["segmentation"]) elif "bbox" in ann: coco_ann["bbox"] = ann["bbox"] return coco_ann def 
export_yolo_v4(all_annotations, class_mapping, image_paths, slices, image_slices, output_dir): # Create output directories train_dir = os.path.join(output_dir, 'train') valid_dir = os.path.join(output_dir, 'valid') for dir_path in [train_dir, valid_dir]: os.makedirs(os.path.join(dir_path, 'images'), exist_ok=True) os.makedirs(os.path.join(dir_path, 'labels'), exist_ok=True) # Create a mapping of class names to YOLO indices class_to_index = {name: i for i, name in enumerate(class_mapping.keys())} # Create a mapping of slice names to their QImage objects slice_map = {slice_name: qimage for slice_name, qimage in slices} for image_name, annotations in all_annotations.items(): # Skip if there are no annotations for this image/slice if not annotations: continue # For simplicity, we'll put all data in the train directory images_dir = os.path.join(train_dir, 'images') labels_dir = os.path.join(train_dir, 'labels') # Handle image saving (similar to before, but adjusted for new directory structure) if image_name in slice_map or ('_' in image_name and '.' 
not in image_name): # Handle slice images qimage = slice_map.get(image_name) or next((s[1] for s in slices if s[0] == image_name), None) if qimage is None: for stack_slices in image_slices.values(): qimage = next((s[1] for s in stack_slices if s[0] == image_name), None) if qimage: break if qimage is None: print(f"No image data found for slice {image_name}, skipping") continue file_name_img = f"{image_name}.png" save_path = os.path.join(images_dir, file_name_img) if not os.path.exists(save_path): qimage.save(save_path) img_width, img_height = qimage.width(), qimage.height() else: # Handle regular images image_path = next((path for name, path in image_paths.items() if image_name in name), None) if not image_path or image_path.lower().endswith(('.tif', '.tiff', '.czi')): print(f"Skipping file: {image_name}") continue file_name_img = image_name dst_path = os.path.join(images_dir, file_name_img) if not os.path.exists(dst_path): shutil.copy2(image_path, dst_path) img = QImage(image_path) img_width, img_height = img.width(), img.height() # Write YOLO format annotation label_file = os.path.splitext(file_name_img)[0] + '.txt' with open(os.path.join(labels_dir, label_file), 'w') as f: for class_name, class_annotations in annotations.items(): class_index = class_to_index[class_name] for ann in class_annotations: if 'segmentation' in ann: polygon = ann['segmentation'] normalized_polygon = [coord / img_width if i % 2 == 0 else coord / img_height for i, coord in enumerate(polygon)] f.write(f"{class_index} " + " ".join(map(lambda x: f"{x:.6f}", normalized_polygon)) + "\n") elif 'bbox' in ann: x, y, w, h = ann['bbox'] x_center = (x + w/2) / img_width y_center = (y + h/2) / img_height w = w / img_width h = h / img_height f.write(f"{class_index} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}\n") # Create YAML file names = list(class_mapping.keys()) yaml_data = { 'train': os.path.abspath(os.path.join(train_dir, 'images')), 'val': os.path.abspath(os.path.join(train_dir, 'images')), # 
Using train as val 'test': '../test/images', # Placeholder 'nc': len(names), 'names': names } # Save YAML file in the output directory yaml_path = os.path.join(output_dir, 'data.yaml') with open(yaml_path, 'w') as f: yaml.dump(yaml_data, f, default_flow_style=False) return train_dir, yaml_path def export_yolo_v5plus(all_annotations, class_mapping, image_paths, slices, image_slices, output_dir): """ Export annotations in YOLO v5+ format. Directory structure: output_dir/ ├── data.yaml ├── images/ │ ├── train/ │ └── val/ └── labels/ ├── train/ └── val/ """ # Create output directories with new structure images_train_dir = os.path.join(output_dir, 'images', 'train') images_val_dir = os.path.join(output_dir, 'images', 'val') labels_train_dir = os.path.join(output_dir, 'labels', 'train') labels_val_dir = os.path.join(output_dir, 'labels', 'val') for dir_path in [images_train_dir, images_val_dir, labels_train_dir, labels_val_dir]: os.makedirs(dir_path, exist_ok=True) # Create a mapping of class names to YOLO indices class_to_index = {name: i for i, name in enumerate(class_mapping.keys())} # Create a mapping of slice names to their QImage objects slice_map = {slice_name: qimage for slice_name, qimage in slices} for image_name, annotations in all_annotations.items(): # Skip if there are no annotations for this image/slice if not annotations: continue # For simplicity, we'll put all data in the train directory # In practice, you might want to implement train/val split logic images_dir = images_train_dir labels_dir = labels_train_dir # Handle image saving (similar logic to the v4 version) if image_name in slice_map or ('_' in image_name and '.' 
not in image_name): # Handle slice images qimage = slice_map.get(image_name) if qimage is None: for stack_slices in image_slices.values(): qimage = next((s[1] for s in stack_slices if s[0] == image_name), None) if qimage: break if qimage is None: print(f"No image data found for slice {image_name}, skipping") continue file_name_img = f"{image_name}.png" save_path = os.path.join(images_dir, file_name_img) if not os.path.exists(save_path): qimage.save(save_path) img_width, img_height = qimage.width(), qimage.height() else: # Handle regular images image_path = next((path for name, path in image_paths.items() if image_name in name), None) if not image_path or image_path.lower().endswith(('.tif', '.tiff', '.czi')): print(f"Skipping file: {image_name}") continue file_name_img = image_name dst_path = os.path.join(images_dir, file_name_img) if not os.path.exists(dst_path): shutil.copy2(image_path, dst_path) img = QImage(image_path) img_width, img_height = img.width(), img.height() # Write YOLO format annotation label_file = os.path.splitext(file_name_img)[0] + '.txt' with open(os.path.join(labels_dir, label_file), 'w') as f: for class_name, class_annotations in annotations.items(): class_index = class_to_index[class_name] for ann in class_annotations: if 'segmentation' in ann: polygon = ann['segmentation'] normalized_polygon = [coord / img_width if i % 2 == 0 else coord / img_height for i, coord in enumerate(polygon)] f.write(f"{class_index} " + " ".join(map(lambda x: f"{x:.6f}", normalized_polygon)) + "\n") elif 'bbox' in ann: x, y, w, h = ann['bbox'] x_center = (x + w/2) / img_width y_center = (y + h/2) / img_height w = w / img_width h = h / img_height f.write(f"{class_index} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}\n") # Create YAML file names = list(class_mapping.keys()) yaml_data = { 'path': os.path.abspath(output_dir), # Root directory 'train': os.path.join('images', 'train'), # Relative to path 'val': os.path.join('images', 'val'), # Relative to path 'nc': 
len(names), 'names': names } # Save YAML file in the output directory yaml_path = os.path.join(output_dir, 'data.yaml') with open(yaml_path, 'w') as f: yaml.dump(yaml_data, f, default_flow_style=False) return output_dir, yaml_path def export_labeled_images(all_annotations, class_mapping, image_paths, slices, image_slices, output_dir): # Create output directories images_dir = os.path.join(output_dir, 'images') labeled_images_dir = os.path.join(output_dir, 'labeled_images') os.makedirs(images_dir, exist_ok=True) os.makedirs(labeled_images_dir, exist_ok=True) # Create a dictionary to store class information for the summary class_summary = {class_name: [] for class_name in class_mapping.keys()} # Create directories for each class inside labeled_images_dir for class_name in class_mapping.keys(): os.makedirs(os.path.join(labeled_images_dir, class_name), exist_ok=True) # Create a mapping of slice names to their QImage objects slice_map = {slice_name: qimage for slice_name, qimage in slices} for image_name, annotations in all_annotations.items(): # Skip if there are no annotations for this image/slice if not annotations: continue # Check if it's a slice (either in slice_map or has underscores and no file extension) is_slice = image_name in slice_map or ('_' in image_name and '.' 
not in image_name) if is_slice: qimage = slice_map.get(image_name) if qimage is None: # If the slice is not in slice_map, it might be a CZI slice or a TIFF slice matching_slices = [s for s in slices if s[0] == image_name] if matching_slices: qimage = matching_slices[0][1] else: # Check in image_slices for stack_slices in image_slices.values(): matching_slices = [s for s in stack_slices if s[0] == image_name] if matching_slices: qimage = matching_slices[0][1] break if qimage is None: print(f"No image data found for slice {image_name}, skipping") continue file_name_img = f"{image_name}.png" # Save the QImage as a file save_path = os.path.join(images_dir, file_name_img) if not os.path.exists(save_path): qimage.save(save_path) else: print(f"Image {file_name_img} already exists in the target directory. Skipping copy.") img_width, img_height = qimage.width(), qimage.height() else: # Check if the image_name exists in image_paths image_path = next((path for name, path in image_paths.items() if image_name in name), None) if not image_path: print(f"No image path found for {image_name}, skipping") continue if image_path.lower().endswith(('.tif', '.tiff', '.czi')): print(f"Skipping main tiff/czi file: {image_name}") continue file_name_img = image_name # Copy the image file dst_path = os.path.join(images_dir, file_name_img) if not os.path.exists(dst_path): shutil.copy2(image_path, dst_path) else: print(f"Image {file_name_img} already exists in the target directory. 
Skipping copy.") img = Image.open(image_path) img_width, img_height = img.size # Create a dictionary to store masks for each class class_masks = {class_name: np.zeros((img_height, img_width), dtype=np.uint16) for class_name in class_mapping.keys()} for class_name, class_annotations in annotations.items(): mask = class_masks[class_name] for ann in class_annotations: object_number = np.max(mask) + 1 # Increment object number for this class if 'segmentation' in ann: polygon = np.array(ann['segmentation']).reshape(-1, 2) rr, cc = skimage.draw.polygon(polygon[:, 1], polygon[:, 0], (img_height, img_width)) mask[rr, cc] = object_number elif 'bbox' in ann: x, y, w, h = map(int, ann['bbox']) mask[y:y+h, x:x+w] = object_number class_summary[class_name].append(file_name_img) # Save masks for each class for class_name, mask in class_masks.items(): if np.any(mask): # Only save if the mask is not empty mask_filename = f"{os.path.splitext(file_name_img)[0]}_{class_name}_mask.png" mask_path = os.path.join(labeled_images_dir, class_name, mask_filename) Image.fromarray(mask.astype(np.uint16)).save(mask_path) # Create summary text file summary_path = os.path.join(labeled_images_dir, 'class_summary.txt') with open(summary_path, 'w') as f: f.write("Classes (folder names):\n") for class_name, files in class_summary.items(): if files: # Only include classes that have annotations f.write(f"- {class_name}\n") f.write(f" Images: {', '.join(sorted(set(files)))}\n\n") return output_dir def export_semantic_labels(all_annotations, class_mapping, image_paths, slices, image_slices, output_dir): # Create output directories images_dir = os.path.join(output_dir, 'images') segmented_images_dir = os.path.join(output_dir, 'segmented_images') os.makedirs(images_dir, exist_ok=True) os.makedirs(segmented_images_dir, exist_ok=True) # Create a mapping of class names to unique pixel values class_to_pixel = {name: i+1 for i, name in enumerate(sorted(class_mapping.keys()))} # Create a mapping of slice names to 
their QImage objects slice_map = {slice_name: qimage for slice_name, qimage in slices} for image_name, annotations in all_annotations.items(): # Skip if there are no annotations for this image/slice if not annotations: continue # Check if it's a slice (either in slice_map or has underscores and no file extension) is_slice = image_name in slice_map or ('_' in image_name and '.' not in image_name) if is_slice: qimage = slice_map.get(image_name) if qimage is None: # If the slice is not in slice_map, it might be a CZI slice or a TIFF slice matching_slices = [s for s in slices if s[0] == image_name] if matching_slices: qimage = matching_slices[0][1] else: # Check in image_slices for stack_slices in image_slices.values(): matching_slices = [s for s in stack_slices if s[0] == image_name] if matching_slices: qimage = matching_slices[0][1] break if qimage is None: print(f"No image data found for slice {image_name}, skipping") continue file_name_img = f"{image_name}.png" # Save the QImage as a file save_path = os.path.join(images_dir, file_name_img) if not os.path.exists(save_path): qimage.save(save_path) else: print(f"Image {file_name_img} already exists in the target directory. Skipping copy.") img_width, img_height = qimage.width(), qimage.height() else: # Check if the image_name exists in image_paths image_path = next((path for name, path in image_paths.items() if image_name in name), None) if not image_path: print(f"No image path found for {image_name}, skipping") continue if image_path.lower().endswith(('.tif', '.tiff', '.czi')): print(f"Skipping main tiff/czi file: {image_name}") continue file_name_img = image_name # Copy the image file dst_path = os.path.join(images_dir, file_name_img) if not os.path.exists(dst_path): shutil.copy2(image_path, dst_path) else: print(f"Image {file_name_img} already exists in the target directory. 
Skipping copy.") img = Image.open(image_path) img_width, img_height = img.size # Create a single mask for all classes semantic_mask = np.zeros((img_height, img_width), dtype=np.uint8) for class_name, class_annotations in annotations.items(): pixel_value = class_to_pixel[class_name] for ann in class_annotations: if 'segmentation' in ann: polygon = np.array(ann['segmentation']).reshape(-1, 2) rr, cc = skimage.draw.polygon(polygon[:, 1], polygon[:, 0], (img_height, img_width)) semantic_mask[rr, cc] = pixel_value elif 'bbox' in ann: x, y, w, h = map(int, ann['bbox']) semantic_mask[y:y+h, x:x+w] = pixel_value # Save semantic mask mask_filename = f"{os.path.splitext(file_name_img)[0]}_semantic_mask.png" mask_path = os.path.join(segmented_images_dir, mask_filename) Image.fromarray(semantic_mask).save(mask_path) # Create class mapping text file mapping_path = os.path.join(segmented_images_dir, 'class_pixel_mapping.txt') with open(mapping_path, 'w') as f: f.write("Pixel Value : Class Name\n") for class_name, pixel_value in class_to_pixel.items(): f.write(f"{pixel_value} : {class_name}\n") return output_dir def export_pascal_voc_bbox(all_annotations, class_mapping, image_paths, slices, image_slices, output_dir): # Create output directories images_dir = os.path.join(output_dir, 'images') annotations_dir = os.path.join(output_dir, 'Annotations') os.makedirs(images_dir, exist_ok=True) os.makedirs(annotations_dir, exist_ok=True) # Create a mapping of slice names to their QImage objects slice_map = {slice_name: qimage for slice_name, qimage in slices} for image_name, annotations in all_annotations.items(): # Skip if there are no annotations for this image/slice if not annotations: continue # Check if it's a slice (either in slice_map or has underscores and no file extension) is_slice = image_name in slice_map or ('_' in image_name and '.' 
not in image_name) if is_slice: qimage = slice_map.get(image_name) if qimage is None: # If the slice is not in slice_map, it might be a CZI slice or a TIFF slice matching_slices = [s for s in slices if s[0] == image_name] if matching_slices: qimage = matching_slices[0][1] else: # Check in image_slices for stack_slices in image_slices.values(): matching_slices = [s for s in stack_slices if s[0] == image_name] if matching_slices: qimage = matching_slices[0][1] break if qimage is None: print(f"No image data found for slice {image_name}, skipping") continue file_name_img = f"{image_name}.png" # Save the QImage as a file save_path = os.path.join(images_dir, file_name_img) if not os.path.exists(save_path): qimage.save(save_path) else: print(f"Image {file_name_img} already exists in the target directory. Skipping copy.") img_width, img_height = qimage.width(), qimage.height() else: # Check if the image_name exists in image_paths image_path = next((path for name, path in image_paths.items() if image_name in name), None) if not image_path: print(f"No image path found for {image_name}, skipping") continue if image_path.lower().endswith(('.tif', '.tiff', '.czi')): print(f"Skipping main tiff/czi file: {image_name}") continue file_name_img = image_name # Copy the image file dst_path = os.path.join(images_dir, file_name_img) if not os.path.exists(dst_path): shutil.copy2(image_path, dst_path) else: print(f"Image {file_name_img} already exists in the target directory. 
Skipping copy.") img = QImage(image_path) img_width, img_height = img.width(), img.height() # Create the XML structure root = ET.Element('annotation') ET.SubElement(root, 'folder').text = 'images' ET.SubElement(root, 'filename').text = file_name_img ET.SubElement(root, 'path').text = os.path.join('images', file_name_img) size = ET.SubElement(root, 'size') ET.SubElement(size, 'width').text = str(img_width) ET.SubElement(size, 'height').text = str(img_height) ET.SubElement(size, 'depth').text = '3' # Assuming RGB images ET.SubElement(root, 'segmented').text = '0' # Add object annotations for class_name, class_annotations in annotations.items(): for ann in class_annotations: obj = ET.SubElement(root, 'object') ET.SubElement(obj, 'name').text = class_name ET.SubElement(obj, 'pose').text = 'Unspecified' ET.SubElement(obj, 'truncated').text = '0' ET.SubElement(obj, 'difficult').text = '0' if 'bbox' in ann: x, y, w, h = ann['bbox'] bndbox = ET.SubElement(obj, 'bndbox') ET.SubElement(bndbox, 'xmin').text = str(int(x)) ET.SubElement(bndbox, 'ymin').text = str(int(y)) ET.SubElement(bndbox, 'xmax').text = str(int(x + w)) ET.SubElement(bndbox, 'ymax').text = str(int(y + h)) # Save the XML file xml_str = minidom.parseString(ET.tostring(root)).toprettyxml(indent=" ") xml_filename = os.path.splitext(file_name_img)[0] + '.xml' with open(os.path.join(annotations_dir, xml_filename), 'w') as f: f.write(xml_str) return output_dir def export_pascal_voc_both(all_annotations, class_mapping, image_paths, slices, image_slices, output_dir): # Create output directories images_dir = os.path.join(output_dir, 'images') annotations_dir = os.path.join(output_dir, 'Annotations') os.makedirs(images_dir, exist_ok=True) os.makedirs(annotations_dir, exist_ok=True) # Create a mapping of slice names to their QImage objects slice_map = {slice_name: qimage for slice_name, qimage in slices} for image_name, annotations in all_annotations.items(): # Skip if there are no annotations for this image/slice if 
not annotations: continue # Check if it's a slice (either in slice_map or has underscores and no file extension) is_slice = image_name in slice_map or ('_' in image_name and '.' not in image_name) if is_slice: qimage = slice_map.get(image_name) if qimage is None: # If the slice is not in slice_map, it might be a CZI slice or a TIFF slice matching_slices = [s for s in slices if s[0] == image_name] if matching_slices: qimage = matching_slices[0][1] else: # Check in image_slices for stack_slices in image_slices.values(): matching_slices = [s for s in stack_slices if s[0] == image_name] if matching_slices: qimage = matching_slices[0][1] break if qimage is None: print(f"No image data found for slice {image_name}, skipping") continue file_name_img = f"{image_name}.png" # Save the QImage as a file save_path = os.path.join(images_dir, file_name_img) if not os.path.exists(save_path): qimage.save(save_path) else: print(f"Image {file_name_img} already exists in the target directory. Skipping copy.") img_width, img_height = qimage.width(), qimage.height() else: # Check if the image_name exists in image_paths image_path = next((path for name, path in image_paths.items() if image_name in name), None) if not image_path: print(f"No image path found for {image_name}, skipping") continue if image_path.lower().endswith(('.tif', '.tiff', '.czi')): print(f"Skipping main tiff/czi file: {image_name}") continue file_name_img = image_name # Copy the image file dst_path = os.path.join(images_dir, file_name_img) if not os.path.exists(dst_path): shutil.copy2(image_path, dst_path) else: print(f"Image {file_name_img} already exists in the target directory. 
Skipping copy.") img = QImage(image_path) img_width, img_height = img.width(), img.height() # Create the XML structure root = ET.Element('annotation') ET.SubElement(root, 'folder').text = 'images' ET.SubElement(root, 'filename').text = file_name_img ET.SubElement(root, 'path').text = os.path.join('images', file_name_img) size = ET.SubElement(root, 'size') ET.SubElement(size, 'width').text = str(img_width) ET.SubElement(size, 'height').text = str(img_height) ET.SubElement(size, 'depth').text = '3' # Assuming RGB images ET.SubElement(root, 'segmented').text = '1' # Set to 1 if segmentation is included # Add object annotations for class_name, class_annotations in annotations.items(): for ann in class_annotations: obj = ET.SubElement(root, 'object') ET.SubElement(obj, 'name').text = class_name ET.SubElement(obj, 'pose').text = 'Unspecified' ET.SubElement(obj, 'truncated').text = '0' ET.SubElement(obj, 'difficult').text = '0' if 'bbox' in ann: x, y, w, h = ann['bbox'] bndbox = ET.SubElement(obj, 'bndbox') ET.SubElement(bndbox, 'xmin').text = str(int(x)) ET.SubElement(bndbox, 'ymin').text = str(int(y)) ET.SubElement(bndbox, 'xmax').text = str(int(x + w)) ET.SubElement(bndbox, 'ymax').text = str(int(y + h)) if 'segmentation' in ann: segmentation = ET.SubElement(obj, 'segmentation') ET.SubElement(segmentation, 'area').text = str(ann.get('area', 0)) # Convert polygon to a list of (x,y) tuples polygon = ann['segmentation'] points = [(polygon[i], polygon[i+1]) for i in range(0, len(polygon), 2)] # Create the polygon element polygon_elem = ET.SubElement(segmentation, 'polygon') for i, (x, y) in enumerate(points): point = ET.SubElement(polygon_elem, f'pt{i+1}') ET.SubElement(point, 'x').text = str(int(x)) ET.SubElement(point, 'y').text = str(int(y)) # Save the XML file xml_str = minidom.parseString(ET.tostring(root)).toprettyxml(indent=" ") xml_filename = os.path.splitext(file_name_img)[0] + '.xml' with open(os.path.join(annotations_dir, xml_filename), 'w') as f: 
f.write(xml_str) return output_dir ================================================ FILE: src/digitalsreeni_image_annotator/help_window.py ================================================ from PyQt5.QtWidgets import QDialog, QVBoxLayout, QTextBrowser from PyQt5.QtCore import Qt from .soft_dark_stylesheet import soft_dark_stylesheet from .default_stylesheet import default_stylesheet class HelpWindow(QDialog): def __init__(self, dark_mode=False, font_size=10): super().__init__() self.setWindowTitle("Help") self.setModal(False) # Make it non-modal self.setGeometry(100, 100, 800, 600) layout = QVBoxLayout() self.text_browser = QTextBrowser() self.text_browser.setOpenExternalLinks(True) layout.addWidget(self.text_browser) self.setLayout(layout) if dark_mode: self.setStyleSheet(soft_dark_stylesheet) else: self.setStyleSheet(default_stylesheet) self.font_size = font_size self.apply_font_size() self.load_help_content() def show_centered(self, parent): parent_geo = parent.geometry() self.move(parent_geo.center() - self.rect().center()) self.show() def apply_font_size(self): self.setStyleSheet(f"QWidget {{ font-size: {self.font_size}pt; }}") font = self.text_browser.font() font.setPointSize(self.font_size) self.text_browser.setFont(font) def load_help_content(self): help_text = """
Image Annotator is a user-friendly GUI tool designed for generating masks for image segmentation and object detection. It allows users to create, edit, and save annotations in various formats, including COCO-style JSON, YOLO v8, and Pascal VOC. Annotations can be defined using manual tools like the polygon tool or in a semi-automated way with the assistance of the Segment Anything Model (SAM-2) pre-trained model. The tool supports multi-dimensional images such as TIFF stacks and CZI files and provides dark mode and adjustable application font sizes for enhanced GUI experience.
The Tools menu provides access to various useful tools for dataset management and image processing. Each tool opens an intuitive GUI to guide you through the process:
If you encounter any issues or have suggestions for improvement, please open an issue on our GitHub repository or contact the development team.
""" self.text_browser.setHtml(help_text) ================================================ FILE: src/digitalsreeni_image_annotator/image_augmenter.py ================================================ import os import random import cv2 import numpy as np import json from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QMessageBox, QSpinBox, QCheckBox, QDoubleSpinBox, QProgressBar, QApplication) from PyQt5.QtCore import Qt class ImageAugmenterDialog(QDialog): def __init__(self, parent=None): super().__init__(parent) self.setWindowTitle("Image Augmenter") self.setGeometry(100, 100, 400, 600) self.setWindowFlags(self.windowFlags() | Qt.Window) self.setWindowModality(Qt.ApplicationModal) self.input_dir = "" self.output_dir = "" self.coco_file = "" self.coco_data = None self.initUI() def initUI(self): layout = QVBoxLayout() # Input directory selection input_layout = QHBoxLayout() self.input_label = QLabel("Input Directory: Not selected") input_button = QPushButton("Select Input Directory") input_button.clicked.connect(self.select_input_directory) input_layout.addWidget(self.input_label) input_layout.addWidget(input_button) layout.addLayout(input_layout) # Output directory selection output_layout = QHBoxLayout() self.output_label = QLabel("Output Directory: Not selected") output_button = QPushButton("Select Output Directory") output_button.clicked.connect(self.select_output_directory) output_layout.addWidget(self.output_label) output_layout.addWidget(output_button) layout.addLayout(output_layout) # Number of augmentations per image aug_count_layout = QHBoxLayout() aug_count_layout.addWidget(QLabel("Augmentations per image:")) self.aug_count_spin = QSpinBox() self.aug_count_spin.setRange(1, 100) self.aug_count_spin.setValue(5) aug_count_layout.addWidget(self.aug_count_spin) layout.addLayout(aug_count_layout) # Add COCO JSON annotation augmentation checkbox and file selection self.coco_check = QCheckBox("Augment COCO JSON 
annotations") self.coco_check.stateChanged.connect(self.toggle_elastic_deformation) layout.addWidget(self.coco_check) coco_layout = QHBoxLayout() self.coco_label = QLabel("COCO JSON File: Not selected") coco_button = QPushButton("Select COCO JSON") coco_button.clicked.connect(self.select_coco_json) coco_layout.addWidget(self.coco_label) coco_layout.addWidget(coco_button) layout.addLayout(coco_layout) # Transformations layout.addWidget(QLabel("Transformations:")) self.rotate_check = QCheckBox("Rotate") self.rotate_spin = QSpinBox() self.rotate_spin.setRange(-180, 180) self.rotate_spin.setValue(30) rotate_layout = QHBoxLayout() rotate_layout.addWidget(self.rotate_check) rotate_layout.addWidget(QLabel("Max degrees:")) rotate_layout.addWidget(self.rotate_spin) layout.addLayout(rotate_layout) self.zoom_check = QCheckBox("Zoom") self.zoom_spin = QDoubleSpinBox() self.zoom_spin.setRange(0.1, 2.0) self.zoom_spin.setValue(0.2) self.zoom_spin.setSingleStep(0.1) zoom_layout = QHBoxLayout() zoom_layout.addWidget(self.zoom_check) zoom_layout.addWidget(QLabel("Scale factor:")) zoom_layout.addWidget(self.zoom_spin) layout.addLayout(zoom_layout) self.blur_check = QCheckBox("Gaussian Blur") layout.addWidget(self.blur_check) self.brightness_contrast_check = QCheckBox("Random Brightness and Contrast") layout.addWidget(self.brightness_contrast_check) self.sharpen_check = QCheckBox("Sharpen") layout.addWidget(self.sharpen_check) # Flip transformation flip_layout = QHBoxLayout() self.flip_check = QCheckBox("Flip") flip_layout.addWidget(self.flip_check) self.flip_horizontal_check = QCheckBox("Horizontal") self.flip_vertical_check = QCheckBox("Vertical") flip_layout.addWidget(self.flip_horizontal_check) flip_layout.addWidget(self.flip_vertical_check) self.flip_horizontal_check.stateChanged.connect(self.update_flip_check) self.flip_vertical_check.stateChanged.connect(self.update_flip_check) layout.addLayout(flip_layout) # Elastic Deformation self.elastic_check = QCheckBox("Elastic 
Deformation") layout.addWidget(self.elastic_check) elastic_layout = QHBoxLayout() elastic_layout.addWidget(self.elastic_check) elastic_layout.addWidget(QLabel("Alpha:")) self.elastic_alpha_spin = QSpinBox() self.elastic_alpha_spin.setRange(1, 1000) self.elastic_alpha_spin.setValue(500) elastic_layout.addWidget(self.elastic_alpha_spin) elastic_layout.addWidget(QLabel("Sigma:")) self.elastic_sigma_spin = QSpinBox() self.elastic_sigma_spin.setRange(1, 100) self.elastic_sigma_spin.setValue(20) elastic_layout.addWidget(self.elastic_sigma_spin) layout.addLayout(elastic_layout) # Grayscale Conversion self.grayscale_check = QCheckBox("Convert to Grayscale") layout.addWidget(self.grayscale_check) # Histogram Equalization self.hist_equalize_check = QCheckBox("Histogram Equalization") layout.addWidget(self.hist_equalize_check) # Augment button self.augment_button = QPushButton("Start Augmentation") self.augment_button.clicked.connect(self.start_augmentation) layout.addWidget(self.augment_button) # Progress bar self.progress_bar = QProgressBar() layout.addWidget(self.progress_bar) self.setLayout(layout) def select_input_directory(self): self.input_dir = QFileDialog.getExistingDirectory(self, "Select Input Directory") if self.input_dir: self.input_label.setText(f"Input Directory: {os.path.basename(self.input_dir)}") def select_output_directory(self): self.output_dir = QFileDialog.getExistingDirectory(self, "Select Output Directory") if self.output_dir: self.output_label.setText(f"Output Directory: {os.path.basename(self.output_dir)}") def update_flip_check(self, state): if self.flip_horizontal_check.isChecked() or self.flip_vertical_check.isChecked(): self.flip_check.setChecked(True) else: self.flip_check.setChecked(False) def select_coco_json(self): self.coco_file, _ = QFileDialog.getOpenFileName(self, "Select COCO JSON File", "", "JSON Files (*.json)") if self.coco_file: self.coco_label.setText(f"COCO JSON File: {os.path.basename(self.coco_file)}") with open(self.coco_file, 
'r') as f: self.coco_data = json.load(f) self.coco_check.setChecked(True) # Automatically check the box when a file is loaded def toggle_elastic_deformation(self, state): if state == Qt.Checked: self.elastic_check.setChecked(False) self.elastic_check.setEnabled(False) else: self.elastic_check.setEnabled(True) def start_augmentation(self): if not self.input_dir or not self.output_dir: QMessageBox.warning(self, "Missing Directory", "Please select both input and output directories.") return if self.coco_check.isChecked() and not self.coco_file: QMessageBox.warning(self, "Missing COCO JSON", "Please select a COCO JSON file for annotation augmentation.") return # Create 'images' subdirectory in the output directory images_output_dir = os.path.join(self.output_dir, "images") os.makedirs(images_output_dir, exist_ok=True) image_files = [f for f in os.listdir(self.input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'))] total_augmentations = len(image_files) * self.aug_count_spin.value() self.progress_bar.setMaximum(total_augmentations) self.progress_bar.setValue(0) augmented_coco_data = { "images": [], "annotations": [], "categories": self.coco_data["categories"] if self.coco_data else [] } next_image_id = 1 next_annotation_id = 1 for i, image_file in enumerate(image_files): input_path = os.path.join(self.input_dir, image_file) image = cv2.imread(input_path, cv2.IMREAD_UNCHANGED) if image is None: print(f"Error loading image: {input_path}") continue # Determine image type and bit depth is_color = len(image.shape) == 3 and image.shape[2] == 3 bit_depth = image.dtype original_annotations = [] if self.coco_check.isChecked(): original_annotations = [ann for ann in self.coco_data["annotations"] if any(img['file_name'] == image_file and img['id'] == ann['image_id'] for img in self.coco_data["images"])] for j in range(self.aug_count_spin.value()): try: augmented, transform_params = self.apply_random_augmentation(image, 
include_annotations=self.coco_check.isChecked()) # Ensure the augmented image has the same properties as the input if not is_color and len(augmented.shape) == 3: augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2GRAY) elif is_color and len(augmented.shape) == 2: augmented = cv2.cvtColor(augmented, cv2.COLOR_GRAY2BGR) augmented = augmented.astype(bit_depth) output_filename = f"{os.path.splitext(image_file)[0]}_aug_{j+1}{os.path.splitext(image_file)[1]}" output_path = os.path.join(images_output_dir, output_filename) cv2.imwrite(output_path, augmented) if self.coco_check.isChecked(): augmented_coco_data["images"].append({ "id": next_image_id, "file_name": output_filename, "height": augmented.shape[0], "width": augmented.shape[1] }) for ann in original_annotations: augmented_ann = self.augment_annotation(ann, transform_params, augmented.shape[:2]) augmented_ann["id"] = next_annotation_id augmented_ann["image_id"] = next_image_id augmented_coco_data["annotations"].append(augmented_ann) next_annotation_id += 1 next_image_id += 1 self.progress_bar.setValue(i * self.aug_count_spin.value() + j + 1) QApplication.processEvents() except Exception as e: print(f"Error processing {image_file} (augmentation {j+1}): {str(e)}") continue # Skip this augmentation and continue with the next if self.coco_check.isChecked(): output_coco_path = os.path.join(self.output_dir, "augmented_annotations.json") with open(output_coco_path, 'w') as f: json.dump(augmented_coco_data, f, indent=2) QMessageBox.information(self, "Augmentation Complete", "Image and annotation augmentation has been completed successfully.") def apply_random_augmentation(self, image, include_annotations=False): augmentations = [] if self.rotate_check.isChecked(): augmentations.append(self.rotate_image) if self.zoom_check.isChecked(): augmentations.append(self.zoom_image) if self.blur_check.isChecked(): augmentations.append(self.blur_image) if self.brightness_contrast_check.isChecked(): 
augmentations.append(self.adjust_brightness_contrast) if self.sharpen_check.isChecked(): augmentations.append(self.sharpen_image) if self.flip_check.isChecked(): augmentations.append(self.flip_image) if self.elastic_check.isChecked() and not include_annotations: augmentations.append(self.elastic_transform) if self.grayscale_check.isChecked(): augmentations.append(self.convert_to_grayscale) if self.hist_equalize_check.isChecked(): augmentations.append(self.apply_histogram_equalization) if not augmentations: return image, {} aug_func = random.choice(augmentations) return aug_func(image) def rotate_image(self, image): angle = random.uniform(-self.rotate_spin.value(), self.rotate_spin.value()) h, w = image.shape[:2] center = (w / 2, h / 2) M = cv2.getRotationMatrix2D(center, -angle, 1.0) # Negative angle for clockwise rotation rotated = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) return rotated, {"type": "rotate", "angle": angle, "center": center, "matrix": M} def zoom_image(self, image): scale = random.uniform(1, 1 + self.zoom_spin.value()) h, w = image.shape[:2] center = (w / 2, h / 2) M = cv2.getRotationMatrix2D(center, 0, scale) zoomed = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) return zoomed, {"type": "zoom", "scale": scale, "center": center, "matrix": M} def blur_image(self, image): kernel_size = random.choice([3, 5, 7]) blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), 0) return blurred, {"type": "blur", "kernel_size": kernel_size} def adjust_brightness_contrast(self, image): alpha = random.uniform(0.5, 1.5) # Contrast control beta = random.uniform(-30, 30) # Brightness control adjusted = cv2.convertScaleAbs(image, alpha=alpha, beta=beta) return adjusted, {"type": "brightness_contrast", "alpha": alpha, "beta": beta} def sharpen_image(self, image): kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]]) sharpened = cv2.filter2D(image, -1, kernel) return sharpened, 
{"type": "sharpen"} def flip_image(self, image): flip_options = [] if self.flip_horizontal_check.isChecked(): flip_options.append(1) # Horizontal flip if self.flip_vertical_check.isChecked(): flip_options.append(0) # Vertical flip if self.flip_horizontal_check.isChecked() and self.flip_vertical_check.isChecked(): flip_options.append(-1) # Both horizontal and vertical if not flip_options: return image, {"type": "flip", "flip_code": None} flip_code = random.choice(flip_options) flipped = cv2.flip(image, flip_code) return flipped, {"type": "flip", "flip_code": flip_code} def elastic_transform(self, image): alpha = self.elastic_alpha_spin.value() sigma = self.elastic_sigma_spin.value() shape = image.shape[:2] random_state = np.random.RandomState(None) dx = random_state.rand(*shape) * 2 - 1 dy = random_state.rand(*shape) * 2 - 1 dx = cv2.GaussianBlur(dx, (0, 0), sigma) * alpha dy = cv2.GaussianBlur(dy, (0, 0), sigma) * alpha x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0])) distorted_x = x + dx distorted_y = y + dy transformed = cv2.remap(image, distorted_x.astype(np.float32), distorted_y.astype(np.float32), interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101) return transformed, {"type": "elastic", "dx": dx, "dy": dy, "shape": shape} def convert_to_grayscale(self, image): if len(image.shape) == 2: return image, {"type": "grayscale"} # Already grayscale gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) gray_3channel = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR) # Convert back to 3 channels return gray_3channel, {"type": "grayscale"} def apply_histogram_equalization(self, image): def equalize_8bit(img): return cv2.equalizeHist(img) def equalize_16bit(img): # Equalize 16-bit image hist, bins = np.histogram(img.flatten(), 65536, [0, 65536]) cdf = hist.cumsum() cdf_normalized = cdf * 65535 / cdf[-1] equalized = np.interp(img.flatten(), bins[:-1], cdf_normalized).reshape(img.shape) return equalized.astype(np.uint16) if len(image.shape) == 2: # Grayscale 
image if image.dtype == np.uint8: equalized = equalize_8bit(image) elif image.dtype == np.uint16: equalized = equalize_16bit(image) else: raise ValueError(f"Unsupported image dtype: {image.dtype}") return equalized, {"type": "histogram_equalization", "mode": "grayscale"} else: # Color image # Convert to YUV color space yuv = cv2.cvtColor(image, cv2.COLOR_BGR2YUV) # Equalize the Y channel if image.dtype == np.uint8: yuv[:,:,0] = equalize_8bit(yuv[:,:,0]) elif image.dtype == np.uint16: yuv[:,:,0] = equalize_16bit(yuv[:,:,0]) else: raise ValueError(f"Unsupported image dtype: {image.dtype}") # Convert back to BGR color space equalized = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR) return equalized, {"type": "histogram_equalization", "mode": "color"} def augment_annotation(self, annotation, transform_params, image_shape): augmented_ann = annotation.copy() if transform_params["type"] == "rotate": angle = transform_params["angle"] center = transform_params["center"] matrix = transform_params["matrix"] augmented_ann["segmentation"] = [self.rotate_polygon(annotation["segmentation"][0], angle, center, matrix)] elif transform_params["type"] == "zoom": scale = transform_params["scale"] center = transform_params["center"] matrix = transform_params["matrix"] augmented_ann["segmentation"] = [self.scale_polygon(annotation["segmentation"][0], scale, center, matrix)] elif transform_params["type"] == "flip": flip_code = transform_params["flip_code"] if flip_code is not None: augmented_ann["segmentation"] = [self.flip_polygon(annotation["segmentation"][0], flip_code, image_shape)] elif transform_params["type"] == "elastic": dx = transform_params["dx"] dy = transform_params["dy"] shape = transform_params["shape"] augmented_ann["segmentation"] = [self.elastic_transform_polygon(annotation["segmentation"][0], dx, dy, shape)] # Recalculate bbox and area for all transformation types if "segmentation" in augmented_ann and augmented_ann["segmentation"]: augmented_ann["bbox"] = 
self.get_bbox_from_polygon(augmented_ann["segmentation"][0]) augmented_ann["area"] = int(self.calculate_polygon_area(augmented_ann["segmentation"][0])) return augmented_ann def calculate_polygon_area(self, polygon): points = np.array(polygon).reshape(-1, 2) return 0.5 * np.abs(np.dot(points[:, 0], np.roll(points[:, 1], 1)) - np.dot(points[:, 1], np.roll(points[:, 0], 1))) def rotate_polygon(self, polygon, angle, center, matrix): points = np.array(polygon).reshape(-1, 2) ones = np.ones(shape=(len(points), 1)) points_ones = np.hstack([points, ones]) transformed_points = matrix.dot(points_ones.T).T return np.round(transformed_points).astype(int).flatten().tolist() def scale_polygon(self, polygon, scale, center, matrix): points = np.array(polygon).reshape(-1, 2) ones = np.ones(shape=(len(points), 1)) points_ones = np.hstack([points, ones]) transformed_points = matrix.dot(points_ones.T).T return np.round(transformed_points).astype(int).flatten().tolist() def flip_polygon(self, polygon, flip_code, image_shape): points = np.array(polygon).reshape(-1, 2) if flip_code == 0: # Vertical flip points[:, 1] = image_shape[0] - points[:, 1] elif flip_code == 1: # Horizontal flip points[:, 0] = image_shape[1] - points[:, 0] elif flip_code == -1: # Both points[:, 0] = image_shape[1] - points[:, 0] points[:, 1] = image_shape[0] - points[:, 1] return np.round(points).astype(int).flatten().tolist() def get_bbox_from_polygon(self, polygon): points = np.array(polygon).reshape(-1, 2) x_min, y_min = np.min(points, axis=0) x_max, y_max = np.max(points, axis=0) return [int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)] def show_centered(self, parent): parent_geo = parent.geometry() self.move(parent_geo.center() - self.rect().center()) self.show() def show_image_augmenter(parent): dialog = ImageAugmenterDialog(parent) dialog.show_centered(parent) return dialog ================================================ FILE: src/digitalsreeni_image_annotator/image_label.py 
================================================ """ ImageLabel module for the Image Annotator application. This module contains the ImageLabel class, which is responsible for displaying the image and handling annotation interactions. @DigitalSreeni Dr. Sreenivas Bhattiprolu """ from PyQt5.QtWidgets import QLabel, QApplication, QMessageBox from PyQt5.QtGui import (QPainter, QPen, QColor, QFont, QPolygonF, QBrush, QPolygon, QPixmap, QImage, QWheelEvent, QMouseEvent, QKeyEvent) from PyQt5.QtCore import Qt, QPoint, QPointF, QRectF, QSize from PIL import Image import os import warnings import cv2 import numpy as np warnings.filterwarnings("ignore", category=UserWarning) class ImageLabel(QLabel): """ A custom QLabel for displaying images and handling annotations. """ def __init__(self, parent=None): super().__init__(parent) self.annotations = {} self.current_annotation = [] self.temp_point = None self.current_tool = None self.zoom_factor = 1.0 self.class_colors = {} self.class_visibility = {} self.start_point = None self.end_point = None self.highlighted_annotations = [] self.setMouseTracking(True) self.setFocusPolicy(Qt.StrongFocus) self.original_pixmap = None self.scaled_pixmap = None self.pan_start_pos = None self.main_window = None self.offset_x = 0 self.offset_y = 0 self.drawing_polygon = False self.editing_polygon = None self.editing_point_index = None self.hover_point_index = None self.fill_opacity = 0.3 self.drawing_rectangle = False self.current_rectangle = None self.bit_depth = None self.image_path = None self.dark_mode = False self.paint_mask = None self.eraser_mask = None self.temp_paint_mask = None self.is_painting = False self.temp_eraser_mask = None self.is_erasing = False self.cursor_pos = None #SAM self.sam_magic_wand_active = False self.sam_bbox = None self.drawing_sam_bbox = False self.temp_sam_prediction = None self.temp_annotations = [] def set_main_window(self, main_window): self.main_window = main_window def set_dark_mode(self, is_dark): 
self.dark_mode = is_dark self.update() def setPixmap(self, pixmap): """Set the pixmap and update the scaled version.""" if isinstance(pixmap, QImage): pixmap = QPixmap.fromImage(pixmap) self.original_pixmap = pixmap self.update_scaled_pixmap() def detect_bit_depth(self): """Detect and store the actual image bit depth using PIL.""" if self.image_path and os.path.exists(self.image_path): with Image.open(self.image_path) as img: if img.mode == '1': self.bit_depth = 1 elif img.mode == 'L': self.bit_depth = 8 elif img.mode == 'I;16': self.bit_depth = 16 elif img.mode in ['RGB', 'HSV']: self.bit_depth = 24 elif img.mode in ['RGBA', 'CMYK']: self.bit_depth = 32 else: self.bit_depth = img.bits if self.main_window: self.main_window.update_image_info() def update_scaled_pixmap(self): if self.original_pixmap and not self.original_pixmap.isNull(): scaled_size = self.original_pixmap.size() * self.zoom_factor self.scaled_pixmap = self.original_pixmap.scaled( scaled_size.width(), scaled_size.height(), Qt.KeepAspectRatio, Qt.SmoothTransformation ) super().setPixmap(self.scaled_pixmap) self.setMinimumSize(self.scaled_pixmap.size()) self.update_offset() else: self.scaled_pixmap = None super().setPixmap(QPixmap()) self.setMinimumSize(QSize(0, 0)) def update_offset(self): """Update the offset for centered image display.""" if self.scaled_pixmap: self.offset_x = int((self.width() - self.scaled_pixmap.width()) / 2) self.offset_y = int((self.height() - self.scaled_pixmap.height()) / 2) def reset_annotation_state(self): """Reset the annotation state.""" self.temp_point = None self.start_point = None self.end_point = None def clear_current_annotation(self): """Clear the current annotation.""" self.current_annotation = [] def resizeEvent(self, event): """Handle resize events.""" super().resizeEvent(event) self.update_offset() def start_painting(self, pos): if self.temp_paint_mask is None: self.temp_paint_mask = np.zeros((self.original_pixmap.height(), self.original_pixmap.width()), 
dtype=np.uint8) self.is_painting = True self.continue_painting(pos) def continue_painting(self, pos): if not self.is_painting: return brush_size = self.main_window.paint_brush_size cv2.circle(self.temp_paint_mask, (int(pos[0]), int(pos[1])), brush_size, 255, -1) self.update() def finish_painting(self): if not self.is_painting: return self.is_painting = False # Don't commit the annotation yet, just keep the temp_paint_mask def commit_paint_annotation(self): if self.temp_paint_mask is not None and self.main_window.current_class: class_name = self.main_window.current_class contours, _ = cv2.findContours(self.temp_paint_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) for contour in contours: if cv2.contourArea(contour) > 10: # Minimum area threshold segmentation = contour.flatten().tolist() new_annotation = { "segmentation": segmentation, "category_id": self.main_window.class_mapping[class_name], "category_name": class_name, } self.annotations.setdefault(class_name, []).append(new_annotation) self.main_window.add_annotation_to_list(new_annotation) self.temp_paint_mask = None self.main_window.save_current_annotations() self.main_window.update_slice_list_colors() self.update() def discard_paint_annotation(self): self.temp_paint_mask = None self.update() def start_erasing(self, pos): if self.temp_eraser_mask is None: self.temp_eraser_mask = np.zeros((self.original_pixmap.height(), self.original_pixmap.width()), dtype=np.uint8) self.is_erasing = True self.continue_erasing(pos) def continue_erasing(self, pos): if not self.is_erasing: return eraser_size = self.main_window.eraser_size cv2.circle(self.temp_eraser_mask, (int(pos[0]), int(pos[1])), eraser_size, 255, -1) self.update() def finish_erasing(self): if not self.is_erasing: return self.is_erasing = False # Don't commit the eraser changes yet, just keep the temp_eraser_mask def commit_eraser_changes(self): if self.temp_eraser_mask is not None: eraser_mask = self.temp_eraser_mask.astype(bool) current_name = 
self.main_window.current_slice or self.main_window.image_file_name annotations_changed = False for class_name, annotations in self.annotations.items(): updated_annotations = [] max_number = max([ann.get('number', 0) for ann in annotations] + [0]) for annotation in annotations: if "segmentation" in annotation and annotation["segmentation"] is not None: points = np.array(annotation["segmentation"]).reshape(-1, 2).astype(int) mask = np.zeros_like(self.temp_eraser_mask) cv2.fillPoly(mask, [points], 255) mask = mask.astype(bool) mask[eraser_mask] = False contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) for i, contour in enumerate(contours): if cv2.contourArea(contour) > 10: # Minimum area threshold new_segmentation = contour.flatten().tolist() new_annotation = annotation.copy() new_annotation["segmentation"] = new_segmentation if i == 0: new_annotation["number"] = annotation.get("number", max_number + 1) else: max_number += 1 new_annotation["number"] = max_number updated_annotations.append(new_annotation) if len(contours) > 1: annotations_changed = True else: updated_annotations.append(annotation) self.annotations[class_name] = updated_annotations self.temp_eraser_mask = None # Update the all_annotations dictionary in the main window self.main_window.all_annotations[current_name] = self.annotations # Call update_annotation_list directly self.main_window.update_annotation_list() self.main_window.save_current_annotations() self.main_window.update_slice_list_colors() self.update() #print(f"Eraser changes committed. 
Annotations changed: {annotations_changed}") #print(f"Current annotations: {self.annotations}") def discard_eraser_changes(self): self.temp_eraser_mask = None self.update() def paintEvent(self, event): super().paintEvent(event) if self.scaled_pixmap: painter = QPainter(self) painter.setRenderHint(QPainter.Antialiasing) # Draw the image painter.drawPixmap(int(self.offset_x), int(self.offset_y), self.scaled_pixmap) # Draw annotations self.draw_annotations(painter) # Draw other elements if self.editing_polygon: self.draw_editing_polygon(painter) if self.drawing_rectangle and self.current_rectangle: self.draw_current_rectangle(painter) if self.sam_magic_wand_active and self.sam_bbox: self.draw_sam_bbox(painter) # Draw temporary paint mask if self.temp_paint_mask is not None: self.draw_temp_paint_mask(painter) # Draw temporary eraser mask if self.temp_eraser_mask is not None: self.draw_temp_eraser_mask(painter) # Draw brush/eraser size indicator self.draw_tool_size_indicator(painter) # Draw temporary YOLO predictions if self.temp_annotations: self.draw_temp_annotations(painter) painter.end() def draw_temp_annotations(self, painter): painter.save() painter.translate(self.offset_x, self.offset_y) painter.scale(self.zoom_factor, self.zoom_factor) for annotation in self.temp_annotations: color = QColor(255, 165, 0, 128) # Semi-transparent orange painter.setPen(QPen(color, 2 / self.zoom_factor, Qt.DashLine)) painter.setBrush(QBrush(color)) if "bbox" in annotation: x, y, w, h = annotation["bbox"] painter.drawRect(QRectF(x, y, w, h)) elif "segmentation" in annotation: points = [QPointF(float(x), float(y)) for x, y in zip(annotation["segmentation"][0::2], annotation["segmentation"][1::2])] painter.drawPolygon(QPolygonF(points)) # Draw label and score painter.setFont(QFont("Arial", int(12 / self.zoom_factor))) label = f"{annotation['category_name']} {annotation['score']:.2f}" if "bbox" in annotation: x, y, _, _ = annotation["bbox"] painter.drawText(QPointF(x, y - 5), label) elif 
"segmentation" in annotation: centroid = self.calculate_centroid(points) if centroid: painter.drawText(centroid, label) painter.restore() def accept_temp_annotations(self): for annotation in self.temp_annotations: class_name = annotation['category_name'] # Check if the class exists, if not, add it if class_name not in self.main_window.class_mapping: self.main_window.add_class(class_name) if class_name not in self.annotations: self.annotations[class_name] = [] del annotation['temp'] del annotation['score'] # Remove the score as it's not needed in the final annotation self.annotations[class_name].append(annotation) self.main_window.add_annotation_to_list(annotation) self.temp_annotations.clear() self.main_window.save_current_annotations() self.main_window.update_slice_list_colors() self.update() def discard_temp_annotations(self): self.temp_annotations.clear() self.update() def draw_temp_paint_mask(self, painter): if self.temp_paint_mask is not None: painter.save() painter.translate(self.offset_x, self.offset_y) painter.scale(self.zoom_factor, self.zoom_factor) mask_image = QImage(self.temp_paint_mask.data, self.temp_paint_mask.shape[1], self.temp_paint_mask.shape[0], self.temp_paint_mask.shape[1], QImage.Format_Grayscale8) mask_pixmap = QPixmap.fromImage(mask_image) painter.setOpacity(0.5) painter.drawPixmap(0, 0, mask_pixmap) painter.setOpacity(1.0) painter.restore() def draw_temp_eraser_mask(self, painter): if self.temp_eraser_mask is not None: painter.save() painter.translate(self.offset_x, self.offset_y) painter.scale(self.zoom_factor, self.zoom_factor) mask_image = QImage(self.temp_eraser_mask.data, self.temp_eraser_mask.shape[1], self.temp_eraser_mask.shape[0], self.temp_eraser_mask.shape[1], QImage.Format_Grayscale8) mask_pixmap = QPixmap.fromImage(mask_image) painter.setOpacity(0.5) painter.drawPixmap(0, 0, mask_pixmap) painter.setOpacity(1.0) painter.restore() def draw_tool_size_indicator(self, painter): if self.current_tool in ["paint_brush", "eraser"] and 
hasattr(self, 'cursor_pos'): painter.save() painter.translate(self.offset_x, self.offset_y) painter.scale(self.zoom_factor, self.zoom_factor) if self.current_tool == "paint_brush": size = self.main_window.paint_brush_size color = QColor(255, 0, 0, 128) # Semi-transparent red else: # eraser size = self.main_window.eraser_size color = QColor(0, 0, 255, 128) # Semi-transparent blue # Draw filled circle with lower opacity painter.setOpacity(0.3) painter.setPen(Qt.NoPen) painter.setBrush(color) painter.drawEllipse(QPointF(self.cursor_pos[0], self.cursor_pos[1]), size, size) # Draw circle outline with full opacity painter.setOpacity(1.0) painter.setPen(QPen(color.darker(150), 1 / self.zoom_factor, Qt.SolidLine)) painter.setBrush(Qt.NoBrush) painter.drawEllipse(QPointF(self.cursor_pos[0], self.cursor_pos[1]), size, size) # Draw size text # Reset the transform to ensure text is drawn at screen coordinates painter.resetTransform() font = QFont() font.setPointSize(10) painter.setFont(font) painter.setPen(QPen(Qt.black)) # Use black color for better visibility # Convert cursor position back to screen coordinates screen_x = self.cursor_pos[0] * self.zoom_factor + self.offset_x screen_y = self.cursor_pos[1] * self.zoom_factor + self.offset_y # Position text above the circle text_rect = QRectF(screen_x + (size * self.zoom_factor), screen_y - (size * self.zoom_factor), 100, 20) text = f"Size: {size}" painter.drawText(text_rect, Qt.AlignLeft | Qt.AlignVCenter, text) painter.restore() def draw_paint_mask(self, painter): if self.paint_mask is not None: mask_image = QImage(self.paint_mask.data, self.paint_mask.shape[1], self.paint_mask.shape[0], self.paint_mask.shape[1], QImage.Format_Grayscale8) mask_pixmap = QPixmap.fromImage(mask_image) painter.setOpacity(0.5) painter.drawPixmap(self.offset_x, self.offset_y, mask_pixmap.scaled(self.scaled_pixmap.size())) painter.setOpacity(1.0) def draw_eraser_mask(self, painter): if self.eraser_mask is not None: mask_image = 
QImage(self.eraser_mask.data, self.eraser_mask.shape[1], self.eraser_mask.shape[0], self.eraser_mask.shape[1], QImage.Format_Grayscale8) mask_pixmap = QPixmap.fromImage(mask_image) painter.setOpacity(0.5) painter.drawPixmap(self.offset_x, self.offset_y, mask_pixmap.scaled(self.scaled_pixmap.size())) painter.setOpacity(1.0) def draw_sam_bbox(self, painter): painter.save() painter.translate(self.offset_x, self.offset_y) painter.scale(self.zoom_factor, self.zoom_factor) painter.setPen(QPen(Qt.red, 2 / self.zoom_factor, Qt.SolidLine)) x1, y1, x2, y2 = self.sam_bbox painter.drawRect(QRectF(min(x1, x2), min(y1, y2), abs(x2 - x1), abs(y2 - y1))) painter.restore() def clear_temp_sam_prediction(self): self.temp_sam_prediction = None self.update() def check_unsaved_changes(self): if self.temp_paint_mask is not None or self.temp_eraser_mask is not None: reply = QMessageBox.question( self.main_window, 'Unsaved Changes', "You have unsaved changes. Do you want to save them?", QMessageBox.Yes | QMessageBox.No | QMessageBox.Cancel ) if reply == QMessageBox.Yes: if self.temp_paint_mask is not None: self.commit_paint_annotation() if self.temp_eraser_mask is not None: self.commit_eraser_changes() return True elif reply == QMessageBox.No: self.discard_paint_annotation() self.discard_eraser_changes() return True else: # Cancel return False return True # No unsaved changes def clear(self): super().clear() self.annotations.clear() self.current_annotation.clear() self.temp_point = None self.current_tool = None self.start_point = None self.end_point = None self.highlighted_annotations.clear() self.original_pixmap = None self.scaled_pixmap = None self.editing_polygon = None self.editing_point_index = None self.hover_point_index = None self.current_rectangle = None self.sam_bbox = None self.temp_sam_prediction = None self.update() def set_class_visibility(self, class_name, is_visible): self.class_visibility[class_name] = is_visible def draw_annotations(self, painter): """Draw all annotations 
on the image.""" if not self.original_pixmap: return painter.save() painter.translate(self.offset_x, self.offset_y) painter.scale(self.zoom_factor, self.zoom_factor) for class_name, class_annotations in self.annotations.items(): if not self.main_window.is_class_visible(class_name): continue color = self.class_colors.get(class_name, QColor(Qt.white)) for annotation in class_annotations: if annotation in self.highlighted_annotations: border_color = Qt.red fill_color = QColor(Qt.red) else: border_color = color fill_color = QColor(color) fill_color.setAlphaF(self.fill_opacity) text_color = Qt.white if self.dark_mode else Qt.black painter.setPen(QPen(border_color, 2 / self.zoom_factor, Qt.SolidLine)) painter.setBrush(QBrush(fill_color)) if "segmentation" in annotation and annotation["segmentation"] is not None: segmentation = annotation["segmentation"] if isinstance(segmentation, list) and len(segmentation) > 0: if isinstance(segmentation[0], list): # Multiple polygons for polygon in segmentation: points = [QPointF(float(x), float(y)) for x, y in zip(polygon[0::2], polygon[1::2])] if points: painter.drawPolygon(QPolygonF(points)) else: # Single polygon points = [QPointF(float(x), float(y)) for x, y in zip(segmentation[0::2], segmentation[1::2])] if points: painter.drawPolygon(QPolygonF(points)) # Draw centroid and label if points: centroid = self.calculate_centroid(points) if centroid: painter.setFont(QFont("Arial", int(12 / self.zoom_factor))) painter.setPen(QPen(text_color, 2 / self.zoom_factor, Qt.SolidLine)) painter.drawText(centroid, f"{class_name} {annotation.get('number', '')}") elif "bbox" in annotation: x, y, width, height = annotation["bbox"] painter.drawRect(QRectF(x, y, width, height)) painter.setPen(QPen(text_color, 2 / self.zoom_factor, Qt.SolidLine)) painter.drawText(QPointF(x, y), f"{class_name} {annotation.get('number', '')}") if self.current_annotation: painter.setPen(QPen(Qt.red, 2 / self.zoom_factor, Qt.SolidLine)) points = [QPointF(float(x), 
float(y)) for x, y in self.current_annotation] if len(points) > 1: painter.drawPolyline(QPolygonF(points)) for point in points: painter.drawEllipse(point, 5 / self.zoom_factor, 5 / self.zoom_factor) if self.temp_point: painter.drawLine(points[-1], QPointF(float(self.temp_point[0]), float(self.temp_point[1]))) # Draw temporary SAM prediction if self.temp_sam_prediction: temp_color = QColor(255, 165, 0, 128) # Semi-transparent orange painter.setPen(QPen(temp_color, 2 / self.zoom_factor, Qt.DashLine)) painter.setBrush(QBrush(temp_color)) segmentation = self.temp_sam_prediction["segmentation"] points = [QPointF(float(x), float(y)) for x, y in zip(segmentation[0::2], segmentation[1::2])] if points: painter.drawPolygon(QPolygonF(points)) centroid = self.calculate_centroid(points) if centroid: painter.setFont(QFont("Arial", int(12 / self.zoom_factor))) painter.drawText(centroid, f"SAM: {self.temp_sam_prediction['score']:.2f}") painter.restore() def draw_current_rectangle(self, painter): """Draw the current rectangle being created.""" if not self.current_rectangle: return painter.save() painter.translate(self.offset_x, self.offset_y) painter.scale(self.zoom_factor, self.zoom_factor) x1, y1, x2, y2 = self.current_rectangle color = self.class_colors.get(self.main_window.current_class, QColor(Qt.red)) painter.setPen(QPen(color, 2 / self.zoom_factor, Qt.SolidLine)) painter.drawRect(QRectF(float(x1), float(y1), float(x2 - x1), float(y2 - y1))) painter.restore() def get_rectangle_from_points(self): """Get rectangle coordinates from start and end points.""" if not self.start_point or not self.end_point: return None x1, y1 = self.start_point x2, y2 = self.end_point return [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] def draw_editing_polygon(self, painter): """Draw the polygon being edited.""" painter.save() painter.translate(self.offset_x, self.offset_y) painter.scale(self.zoom_factor, self.zoom_factor) points = [QPointF(float(x), float(y)) for x, y in 
zip(self.editing_polygon["segmentation"][0::2], self.editing_polygon["segmentation"][1::2])] color = self.class_colors.get(self.editing_polygon["category_name"], QColor(Qt.white)) fill_color = QColor(color) fill_color.setAlphaF(self.fill_opacity) painter.setPen(QPen(color, 2 / self.zoom_factor, Qt.SolidLine)) painter.setBrush(QBrush(fill_color)) painter.drawPolygon(QPolygonF(points)) # Changed QPolygon to QPolygonF - Sreeni for i, point in enumerate(points): if i == self.hover_point_index: painter.setBrush(QColor(255, 0, 0)) else: painter.setBrush(QColor(0, 255, 0)) painter.drawEllipse(point, 5 / self.zoom_factor, 5 / self.zoom_factor) painter.restore() def calculate_centroid(self, points): """Calculate the centroid of a polygon.""" if not points: return None x_coords = [point.x() for point in points] y_coords = [point.y() for point in points] centroid_x = sum(x_coords) / len(points) centroid_y = sum(y_coords) / len(points) return QPointF(centroid_x, centroid_y) def set_zoom(self, zoom_factor): """Set the zoom factor and update the display.""" self.zoom_factor = zoom_factor self.update_scaled_pixmap() self.update() def wheelEvent(self, event: QWheelEvent): if event.modifiers() == Qt.ControlModifier: delta = event.angleDelta().y() if delta > 0: self.main_window.zoom_in() else: self.main_window.zoom_out() event.accept() else: super().wheelEvent(event) def mousePressEvent(self, event: QMouseEvent): if not self.original_pixmap: return if event.modifiers() == Qt.ControlModifier and event.button() == Qt.LeftButton: self.pan_start_pos = event.pos() self.setCursor(Qt.ClosedHandCursor) event.accept() else: pos = self.get_image_coordinates(event.pos()) if event.button() == Qt.LeftButton: if self.sam_magic_wand_active: self.sam_bbox = [pos[0], pos[1], pos[0], pos[1]] self.drawing_sam_bbox = True elif self.editing_polygon: self.handle_editing_click(pos, event) elif self.current_tool == "polygon": if not self.drawing_polygon: self.drawing_polygon = True self.current_annotation 
= [] self.current_annotation.append(pos) elif self.current_tool == "rectangle": self.start_point = pos self.end_point = pos self.drawing_rectangle = True self.current_rectangle = None elif self.current_tool == "paint_brush": self.start_painting(pos) elif self.current_tool == "eraser": self.start_erasing(pos) self.update() def mouseMoveEvent(self, event: QMouseEvent): if not self.original_pixmap: return self.cursor_pos = self.get_image_coordinates(event.pos()) if event.modifiers() == Qt.ControlModifier and event.buttons() == Qt.LeftButton: if self.pan_start_pos: delta = event.pos() - self.pan_start_pos scrollbar_h = self.main_window.scroll_area.horizontalScrollBar() scrollbar_v = self.main_window.scroll_area.verticalScrollBar() scrollbar_h.setValue(scrollbar_h.value() - delta.x()) scrollbar_v.setValue(scrollbar_v.value() - delta.y()) self.pan_start_pos = event.pos() event.accept() else: pos = self.cursor_pos if self.sam_magic_wand_active and self.drawing_sam_bbox: if self.sam_bbox is not None: self.sam_bbox[2] = pos[0] self.sam_bbox[3] = pos[1] elif self.editing_polygon: self.handle_editing_move(pos) elif self.current_tool == "polygon" and self.current_annotation: self.temp_point = pos elif self.current_tool == "rectangle" and self.drawing_rectangle: self.end_point = pos self.current_rectangle = self.get_rectangle_from_points() elif self.current_tool == "paint_brush" and event.buttons() == Qt.LeftButton: self.continue_painting(pos) elif self.current_tool == "eraser" and event.buttons() == Qt.LeftButton: self.continue_erasing(pos) self.update() def mouseReleaseEvent(self, event: QMouseEvent): if not self.original_pixmap: return if event.modifiers() == Qt.ControlModifier and event.button() == Qt.LeftButton: self.pan_start_pos = None self.setCursor(Qt.ArrowCursor) event.accept() else: pos = self.get_image_coordinates(event.pos()) if event.button() == Qt.LeftButton: if self.sam_magic_wand_active and self.drawing_sam_bbox: if self.sam_bbox is not None: self.sam_bbox[2] = 
pos[0] self.sam_bbox[3] = pos[1] self.drawing_sam_bbox = False self.main_window.apply_sam_prediction() elif self.editing_polygon: self.editing_point_index = None elif self.current_tool == "rectangle" and self.drawing_rectangle: self.drawing_rectangle = False if self.current_rectangle: self.main_window.finish_rectangle() elif self.current_tool == "paint_brush": self.finish_painting() elif self.current_tool == "eraser": self.finish_erasing() self.update() def mouseDoubleClickEvent(self, event): if not self.pixmap(): return pos = self.get_image_coordinates(event.pos()) if event.button() == Qt.LeftButton: if self.drawing_polygon and len(self.current_annotation) > 2: self.finish_polygon() else: self.clear_current_annotation() annotation = self.start_polygon_edit(pos) if annotation: self.main_window.select_annotation_in_list(annotation) self.update() def get_image_coordinates(self, pos): if not self.scaled_pixmap: return (0, 0) x = (pos.x() - self.offset_x) / self.zoom_factor y = (pos.y() - self.offset_y) / self.zoom_factor return (int(x), int(y)) def keyPressEvent(self, event: QKeyEvent): if event.key() == Qt.Key_Return or event.key() == Qt.Key_Enter: if self.temp_annotations: self.accept_temp_annotations() elif self.temp_sam_prediction: self.main_window.accept_sam_prediction() elif self.editing_polygon: self.editing_polygon = None self.editing_point_index = None self.hover_point_index = None self.main_window.enable_tools() self.main_window.update_annotation_list() elif self.current_tool == "polygon" and self.drawing_polygon: self.finish_polygon() elif self.current_tool == "paint_brush": self.commit_paint_annotation() elif self.current_tool == "eraser": self.commit_eraser_changes() else: self.finish_current_annotation() elif event.key() == Qt.Key_Escape: if self.temp_annotations: self.discard_temp_annotations() elif self.sam_magic_wand_active: self.sam_bbox = None self.clear_temp_sam_prediction() elif self.editing_polygon: self.editing_polygon = None 
self.editing_point_index = None self.hover_point_index = None self.main_window.enable_tools() elif self.current_tool == "paint_brush": self.discard_paint_annotation() elif self.current_tool == "eraser": self.discard_eraser_changes() else: self.cancel_current_annotation() elif event.key() == Qt.Key_Delete: if self.editing_polygon: self.main_window.delete_selected_annotations() self.editing_polygon = None self.editing_point_index = None self.hover_point_index = None self.main_window.enable_tools() self.update() elif event.key() == Qt.Key_Minus: if self.current_tool == "paint_brush": self.main_window.paint_brush_size = max(1, self.main_window.paint_brush_size - 1) print(f"Paint brush size: {self.main_window.paint_brush_size}") elif self.current_tool == "eraser": self.main_window.eraser_size = max(1, self.main_window.eraser_size - 1) print(f"Eraser size: {self.main_window.eraser_size}") elif event.key() == Qt.Key_Equal: if self.current_tool == "paint_brush": self.main_window.paint_brush_size += 1 print(f"Paint brush size: {self.main_window.paint_brush_size}") elif self.current_tool == "eraser": self.main_window.eraser_size += 1 print(f"Eraser size: {self.main_window.eraser_size}") self.update() def cancel_current_annotation(self): """Cancel the current annotation being created.""" if self.current_tool == "polygon" and self.current_annotation: self.current_annotation = [] self.temp_point = None self.drawing_polygon = False self.update() def finish_current_annotation(self): """Finish the current annotation being created.""" if self.current_tool == "polygon" and len(self.current_annotation) > 2: if self.main_window: self.main_window.finish_polygon() def finish_polygon(self): """Finish the current polygon annotation.""" if self.drawing_polygon and len(self.current_annotation) > 2: self.drawing_polygon = False if self.main_window: self.main_window.finish_polygon() def start_polygon_edit(self, pos): for class_name, annotations in self.annotations.items(): for annotation in 
annotations: if "segmentation" in annotation: points = [QPoint(int(x), int(y)) for x, y in zip(annotation["segmentation"][0::2], annotation["segmentation"][1::2])] if self.point_in_polygon(pos, points): self.editing_polygon = annotation self.current_tool = None self.main_window.disable_tools() self.main_window.reset_tool_buttons() return annotation return None def handle_editing_click(self, pos, event): """Handle clicks during polygon editing.""" points = [QPoint(int(x), int(y)) for x, y in zip(self.editing_polygon["segmentation"][0::2], self.editing_polygon["segmentation"][1::2])] for i, point in enumerate(points): if self.distance(pos, point) < 10 / self.zoom_factor: if event.modifiers() & Qt.ShiftModifier: # Delete point del self.editing_polygon["segmentation"][i*2:i*2+2] else: # Start moving point self.editing_point_index = i return # Add new point for i in range(len(points)): if self.point_on_line(pos, points[i], points[(i+1) % len(points)]): self.editing_polygon["segmentation"][i*2+2:i*2+2] = [pos[0], pos[1]] self.editing_point_index = i + 1 return def handle_editing_move(self, pos): """Handle mouse movement during polygon editing.""" points = [QPoint(int(x), int(y)) for x, y in zip(self.editing_polygon["segmentation"][0::2], self.editing_polygon["segmentation"][1::2])] self.hover_point_index = None for i, point in enumerate(points): if self.distance(pos, point) < 10 / self.zoom_factor: self.hover_point_index = i break if self.editing_point_index is not None: self.editing_polygon["segmentation"][self.editing_point_index*2] = pos[0] self.editing_polygon["segmentation"][self.editing_point_index*2+1] = pos[1] def exit_editing_mode(self): self.editing_polygon = None self.editing_point_index = None self.hover_point_index = None self.update() @staticmethod def point_in_polygon(point, polygon): """Check if a point is inside a polygon.""" n = len(polygon) inside = False p1x, p1y = polygon[0].x(), polygon[0].y() for i in range(n + 1): p2x, p2y = polygon[i % n].x(), 
polygon[i % n].y() if point[1] > min(p1y, p2y): if point[1] <= max(p1y, p2y): if point[0] <= max(p1x, p2x): if p1y != p2y: xinters = (point[1] - p1y) * (p2x - p1x) / (p2y - p1y) + p1x if p1x == p2x or point[0] <= xinters: inside = not inside p1x, p1y = p2x, p2y return inside @staticmethod def point_to_tuple(point): """Convert QPoint to tuple.""" if isinstance(point, QPoint): return (point.x(), point.y()) return point @staticmethod def distance(p1, p2): """Calculate distance between two points.""" p1 = ImageLabel.point_to_tuple(p1) p2 = ImageLabel.point_to_tuple(p2) return ((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5 @staticmethod def point_on_line(p, start, end): """Check if a point is on a line segment.""" p = ImageLabel.point_to_tuple(p) start = ImageLabel.point_to_tuple(start) end = ImageLabel.point_to_tuple(end) d1 = ImageLabel.distance(p, start) d2 = ImageLabel.distance(p, end) line_length = ImageLabel.distance(start, end) buffer = 0.1 # Adjust this value for more or less strict "on-line" detection return abs(d1 + d2 - line_length) < buffer ================================================ FILE: src/digitalsreeni_image_annotator/image_patcher.py ================================================ import os import numpy as np from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QFileDialog, QSpinBox, QProgressBar, QMessageBox, QListWidget, QDialogButtonBox, QGridLayout, QComboBox, QApplication, QScrollArea, QWidget) from PyQt5.QtCore import Qt, QThread, pyqtSignal from PyQt5.QtCore import QTimer, QEventLoop from tifffile import TiffFile, imsave from PIL import Image import traceback class DimensionDialog(QDialog): def __init__(self, shape, file_name, parent=None): super().__init__(parent) self.shape = shape self.file_name = file_name self.initUI() def initUI(self): layout = QVBoxLayout() self.setLayout(layout) layout.addWidget(QLabel(f"File: {self.file_name}")) layout.addWidget(QLabel(f"Image shape: {self.shape}")) 
layout.addWidget(QLabel("Assign dimensions:")) grid_layout = QGridLayout() self.combos = [] dimensions = ['T', 'Z', 'C', 'H', 'W'] for i, dim in enumerate(self.shape): grid_layout.addWidget(QLabel(f"Dimension {i} (size {dim}):"), i, 0) combo = QComboBox() combo.addItems(dimensions) grid_layout.addWidget(combo, i, 1) self.combos.append(combo) layout.addLayout(grid_layout) self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) self.button_box.accepted.connect(self.accept) self.button_box.rejected.connect(self.reject) layout.addWidget(self.button_box) def get_dimensions(self): return [combo.currentText() for combo in self.combos] class PatchingThread(QThread): progress = pyqtSignal(int) error = pyqtSignal(str) finished = pyqtSignal() dimension_required = pyqtSignal(object, str) def __init__(self, input_files, output_dir, patch_size, overlap, dimensions): super().__init__() self.input_files = input_files self.output_dir = output_dir self.patch_size = patch_size self.overlap = overlap # Changed to tuple (to handle overlap_x, overlap_y independently) - Sreeni self.dimensions = dimensions def run(self): try: total_files = len(self.input_files) for i, file_path in enumerate(self.input_files): self.patch_image(file_path) self.progress.emit(int((i + 1) / total_files * 100)) self.finished.emit() except Exception as e: self.error.emit(str(e)) traceback.print_exc() def patch_image(self, file_path): file_name = os.path.basename(file_path) file_name_without_ext, file_extension = os.path.splitext(file_name) if file_extension.lower() in ['.tif', '.tiff']: with TiffFile(file_path) as tif: images = tif.asarray() if images.ndim > 2: if file_path not in self.dimensions: self.dimension_required.emit(images.shape, file_name) self.wait() dimensions = self.dimensions.get(file_path) if dimensions: if 'H' in dimensions and 'W' in dimensions: h_index = dimensions.index('H') w_index = dimensions.index('W') for idx in np.ndindex(images.shape[:h_index] + 
images.shape[h_index+2:]): slice_idx = idx[:h_index] + (slice(None), slice(None)) + idx[h_index:] image = images[slice_idx] slice_name = '_'.join([f'{dim}{i+1}' for dim, i in zip(dimensions, idx) if dim not in ['H', 'W']]) self.save_patches(image, f"{file_name_without_ext}_{slice_name}", file_extension) else: raise ValueError("You must assign both H and W dimensions.") else: raise ValueError("Dimensions were not properly assigned.") else: self.save_patches(images, file_name_without_ext, file_extension) else: with Image.open(file_path) as img: image = np.array(img) self.save_patches(image, file_name_without_ext, file_extension) def save_patches(self, image, base_name, extension): h, w = image.shape[:2] patch_h, patch_w = self.patch_size overlap_x, overlap_y = self.overlap for i in range(0, h - overlap_y, patch_h - overlap_y): for j in range(0, w - overlap_x, patch_w - overlap_x): if i + patch_h <= h and j + patch_w <= w: # Only save full-sized patches patch = image[i:i+patch_h, j:j+patch_w] patch_name = f"{base_name}_patch_{i}_{j}{extension}" output_path = os.path.join(self.output_dir, patch_name) if extension.lower() in ['.tif', '.tiff']: imsave(output_path, patch) else: Image.fromarray(patch).save(output_path) class ImagePatcherTool(QDialog): def __init__(self, parent=None): super().__init__(parent) self.setWindowModality(Qt.ApplicationModal) self.dimensions = {} self.input_files = [] self.output_dir = "" self.initUI() def initUI(self): layout = QVBoxLayout() self.setLayout(layout) # Input files selection input_layout = QHBoxLayout() self.input_label = QLabel("Input Files:") self.input_button = QPushButton("Select Files") self.input_button.clicked.connect(self.select_input_files) input_layout.addWidget(self.input_label) input_layout.addWidget(self.input_button) layout.addLayout(input_layout) # Output directory selection output_layout = QHBoxLayout() self.output_label = QLabel("Output Directory:") self.output_button = QPushButton("Select Directory") 
self.output_button.clicked.connect(self.select_output_directory) output_layout.addWidget(self.output_label) output_layout.addWidget(self.output_button) layout.addLayout(output_layout) # Patch size inputs patch_layout = QHBoxLayout() patch_layout.addWidget(QLabel("Patch Size (W x H):")) self.patch_w = QSpinBox() self.patch_w.setRange(1, 10000) self.patch_w.setValue(256) self.patch_h = QSpinBox() self.patch_h.setRange(1, 10000) self.patch_h.setValue(256) patch_layout.addWidget(self.patch_w) patch_layout.addWidget(self.patch_h) layout.addLayout(patch_layout) # Overlap inputs overlap_layout = QHBoxLayout() overlap_layout.addWidget(QLabel("Overlap (X, Y):")) self.overlap_x = QSpinBox() self.overlap_x.setRange(0, 1000) self.overlap_x.setValue(0) self.overlap_y = QSpinBox() self.overlap_y.setRange(0, 1000) self.overlap_y.setValue(0) overlap_layout.addWidget(self.overlap_x) overlap_layout.addWidget(self.overlap_y) layout.addLayout(overlap_layout) # Create a scroll area for patch info scroll_area = QScrollArea() scroll_area.setWidgetResizable(True) scroll_area.setMinimumHeight(200) # Set a minimum height for the scroll area # Create a widget to hold the patch info label self.patch_info_container = QWidget() patch_info_layout = QVBoxLayout(self.patch_info_container) # Add the patch info label to the container self.patch_info_label = QLabel() self.patch_info_label.setAlignment(Qt.AlignLeft | Qt.AlignTop) patch_info_layout.addWidget(self.patch_info_label) # Set the container as the scroll area's widget scroll_area.setWidget(self.patch_info_container) # Add the scroll area to the main layout layout.addWidget(scroll_area) # Start button self.start_button = QPushButton("Start Patching") self.start_button.clicked.connect(self.start_patching) layout.addWidget(self.start_button) # Progress bar self.progress_bar = QProgressBar() layout.addWidget(self.progress_bar) self.setWindowTitle('Image Patcher Tool') self.setMinimumWidth(500) # Set a minimum width for the dialog 
self.setMinimumHeight(600) # Set a minimum height for the dialog # Connect value changed signals self.patch_w.valueChanged.connect(self.update_patch_info) self.patch_h.valueChanged.connect(self.update_patch_info) self.overlap_x.valueChanged.connect(self.update_patch_info) self.overlap_y.valueChanged.connect(self.update_patch_info) def select_input_files(self): file_dialog = QFileDialog() self.input_files, _ = file_dialog.getOpenFileNames(self, "Select Input Files", "", "Image Files (*.png *.jpg *.bmp *.tif *.tiff)") self.input_label.setText(f"Input Files: {len(self.input_files)} selected") QApplication.processEvents() self.process_tiff_files() self.update_patch_info() def process_tiff_files(self): for file_path in self.input_files: if file_path.lower().endswith(('.tif', '.tiff')): self.check_tiff_dimensions(file_path) QApplication.processEvents() def check_tiff_dimensions(self, file_path): with TiffFile(file_path) as tif: images = tif.asarray() if images.ndim > 2: file_name = os.path.basename(file_path) dialog = DimensionDialog(images.shape, file_name, self) dialog.setWindowModality(Qt.ApplicationModal) result = dialog.exec_() if result == QDialog.Accepted: dimensions = dialog.get_dimensions() if 'H' in dimensions and 'W' in dimensions: self.dimensions[file_path] = dimensions else: QMessageBox.warning(self, "Invalid Dimensions", f"You must assign both H and W dimensions for {file_name}.") QApplication.processEvents() def select_output_directory(self): file_dialog = QFileDialog() self.output_dir = file_dialog.getExistingDirectory(self, "Select Output Directory") dir_name = os.path.basename(self.output_dir) if self.output_dir else "" self.output_label.setText(f"Output Directory: {dir_name}") QApplication.processEvents() self.update_patch_info() def start_patching(self): if not self.input_files: QMessageBox.warning(self, "No Input Files", "Please select input files.") return if not self.output_dir: QMessageBox.warning(self, "No Output Directory", "Please select an 
output directory.") return patch_size = (self.patch_h.value(), self.patch_w.value()) overlap = (self.overlap_x.value(), self.overlap_y.value()) self.patching_thread = PatchingThread(self.input_files, self.output_dir, patch_size, overlap, self.dimensions) self.patching_thread.progress.connect(self.update_progress) self.patching_thread.error.connect(self.show_error) self.patching_thread.finished.connect(self.patching_finished) self.patching_thread.dimension_required.connect(self.get_dimensions) self.patching_thread.start() self.start_button.setEnabled(False) def get_dimensions(self, shape, file_name): dialog = DimensionDialog(shape, file_name, self) dialog.setWindowModality(Qt.ApplicationModal) result = dialog.exec_() if result == QDialog.Accepted: dimensions = dialog.get_dimensions() if 'H' in dimensions and 'W' in dimensions: self.dimensions[file_name] = dimensions else: QMessageBox.warning(self, "Invalid Dimensions", f"You must assign both H and W dimensions for {file_name}.") QApplication.processEvents() self.patching_thread.wake() def get_patch_info(self): patch_info = {} patch_w = self.patch_w.value() patch_h = self.patch_h.value() overlap_x = self.overlap_x.value() overlap_y = self.overlap_y.value() for file_path in self.input_files: file_name = os.path.basename(file_path) if file_path.lower().endswith(('.tif', '.tiff')): with TiffFile(file_path) as tif: images = tif.asarray() if images.ndim > 2: dimensions = self.dimensions.get(file_path) if dimensions: h_index = dimensions.index('H') w_index = dimensions.index('W') h, w = images.shape[h_index], images.shape[w_index] else: h, w = images.shape[-2], images.shape[-1] else: h, w = images.shape else: with Image.open(file_path) as img: w, h = img.size patches_x = (w - overlap_x) // (patch_w - overlap_x) patches_y = (h - overlap_y) // (patch_h - overlap_y) leftover_x = w - (patches_x * (patch_w - overlap_x) + overlap_x) leftover_y = h - (patches_y * (patch_h - overlap_y) + overlap_y) patch_info[file_name] = { 
'patches_x': patches_x, 'patches_y': patches_y, 'leftover_x': leftover_x, 'leftover_y': leftover_y } return patch_info def update_patch_info(self): if not self.input_files: self.patch_info_label.setText("No input files selected") return patch_info = self.get_patch_info() if patch_info: info_text = "Patch Information:
" for file_name, info in patch_info.items(): info_text += f"File: {file_name}
"
info_text += f"Patches: X: {info['patches_x']}, Y: {info['patches_y']}
"
info_text += f"Leftover pixels: X: {info['leftover_x']}, Y: {info['leftover_y']}
{key}:{value}
") else: formatted_stats.append(f"{line}
") stats_label = QLabel("".join(formatted_stats)) stats_label.setTextFormat(Qt.RichText) stats_label.setWordWrap(True) scroll_layout.addWidget(stats_label) scroll_area.setWidget(scroll_content) layout.addWidget(scroll_area) # Project notes layout.addWidget(bold_label("Project Notes:")) self.notes_edit = QTextEdit() self.notes_edit.setPlainText(getattr(self.parent, 'project_notes', '')) layout.addWidget(self.notes_edit) # Buttons button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) button_box.accepted.connect(self.accept) button_box.rejected.connect(self.reject) layout.addWidget(button_box) def get_notes(self): return self.notes_edit.toPlainText() def were_changes_made(self): return self.get_notes() != self.original_notes ================================================ FILE: src/digitalsreeni_image_annotator/project_search.py ================================================ from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QLineEdit, QPushButton, QDateEdit, QLabel, QListWidget, QDialogButtonBox, QFormLayout, QFileDialog, QMessageBox) from PyQt5.QtCore import Qt, QDate import os import json from datetime import datetime class ProjectSearchDialog(QDialog): def __init__(self, parent=None): super().__init__(parent) self.parent = parent self.setWindowTitle("Search Projects") self.setModal(True) self.setMinimumSize(600, 400) self.search_directory = "" self.setup_ui() def setup_ui(self): layout = QVBoxLayout(self) # Search criteria form_layout = QFormLayout() self.keyword_edit = QLineEdit() self.keyword_edit.setPlaceholderText("Enter search query (e.g., monkey AND dog AND (project_animals OR project_zoo))") form_layout.addRow("Search Query:", self.keyword_edit) self.start_date = QDateEdit() self.start_date.setCalendarPopup(True) self.start_date.setDate(QDate.currentDate().addYears(-1)) form_layout.addRow("Start Date:", self.start_date) self.end_date = QDateEdit() self.end_date.setCalendarPopup(True) 
self.end_date.setDate(QDate.currentDate()) form_layout.addRow("End Date:", self.end_date) layout.addLayout(form_layout) # Directory selection dir_layout = QHBoxLayout() self.dir_edit = QLineEdit() dir_layout.addWidget(self.dir_edit) dir_button = QPushButton("Browse") dir_button.clicked.connect(self.browse_directory) dir_layout.addWidget(dir_button) layout.addLayout(dir_layout) # Search button search_button = QPushButton("Search") search_button.clicked.connect(self.perform_search) layout.addWidget(search_button) # Results list self.results_list = QListWidget() self.results_list.itemDoubleClicked.connect(self.open_selected_project) layout.addWidget(self.results_list) # Buttons button_box = QDialogButtonBox(QDialogButtonBox.Close) button_box.rejected.connect(self.reject) layout.addWidget(button_box) def browse_directory(self): directory = QFileDialog.getExistingDirectory(self, "Select Directory to Search") if directory: self.search_directory = directory self.dir_edit.setText(directory) def perform_search(self): if not self.search_directory: QMessageBox.warning(self, "No Directory", "Please select a directory to search.") return query = self.keyword_edit.text() start_date = self.start_date.date().toPyDate() end_date = self.end_date.date().toPyDate() self.results_list.clear() for root, dirs, files in os.walk(self.search_directory): for filename in files: if filename.endswith('.iap'): project_path = os.path.join(root, filename) try: with open(project_path, 'r') as f: project_data = json.load(f) if self.project_matches(project_data, query, start_date, end_date): self.results_list.addItem(project_path) except Exception as e: print(f"Error reading project file {filename}: {str(e)}") if self.results_list.count() == 0: QMessageBox.information(self, "Search Results", "No matching projects found.") else: QMessageBox.information(self, "Search Results", f"{self.results_list.count()} matching projects found.") def project_matches(self, project_data, query, start_date, end_date): # 
Check date range creation_date = project_data.get('creation_date', '') if creation_date: try: creation_date = datetime.fromisoformat(creation_date).date() if creation_date < start_date or creation_date > end_date: return False except ValueError: print(f"Invalid date format in project: {creation_date}") if not query: return True return self.evaluate_query(query.lower(), project_data) def term_matches(self, term, project_data): # Search in project name if term in os.path.basename(project_data.get('current_project_file', '')).lower(): return True # Search in classes if any(term in class_info['name'].lower() for class_info in project_data.get('classes', [])): return True # Search in image names if any(term in img['file_name'].lower() for img in project_data.get('images', [])): return True # Search in project notes if term in project_data.get('notes', '').lower(): return True return False def evaluate_query(self, query, project_data): tokens = self.tokenize_query(query) return self.evaluate_tokens(tokens, project_data) def tokenize_query(self, query): tokens = [] current_token = "" for char in query: if char in '()': if current_token: tokens.append(current_token) current_token = "" tokens.append(char) elif char.isspace(): if current_token: tokens.append(current_token) current_token = "" else: current_token += char if current_token: tokens.append(current_token) return tokens def evaluate_tokens(self, tokens, project_data): def evaluate_expression(): nonlocal i result = True current_op = 'and' while i < len(tokens): if tokens[i] == '(': i += 1 sub_result = evaluate_expression() if current_op == 'and': result = result and sub_result else: result = result or sub_result elif tokens[i] == ')': return result elif tokens[i].lower() in ['and', 'or']: current_op = tokens[i].lower() else: term_result = self.term_matches(tokens[i], project_data) if current_op == 'and': result = result and term_result else: result = result or term_result i += 1 return result i = 0 return 
evaluate_expression() def keyword_matches(self, keyword, project_data): # Search in project name if keyword in os.path.basename(project_data.get('current_project_file', '')).lower().split(): return True # Search in classes if any(keyword in class_info['name'].lower().split() for class_info in project_data.get('classes', [])): return True # Search in image names if any(keyword in img['file_name'].lower().split() for img in project_data.get('images', [])): return True # Search in project notes if keyword in project_data.get('notes', '').lower().split(): return True # Search in creation date and last modified date if keyword in project_data.get('creation_date', '').lower().split() or keyword in project_data.get('last_modified', '').lower().split(): return True return False def open_selected_project(self, item): project_file = item.text() self.parent.open_specific_project(project_file) self.accept() def show_project_search(parent): dialog = ProjectSearchDialog(parent) dialog.exec_() ================================================ FILE: src/digitalsreeni_image_annotator/sam_utils.py ================================================ import numpy as np from PyQt5.QtGui import QImage, QColor from ultralytics import SAM class SAMUtils: def __init__(self): self.sam_models = { "SAM 2 tiny": "sam2_t.pt", "SAM 2 small": "sam2_s.pt", "SAM 2 base": "sam2_b.pt", "SAM 2 large": "sam2_l.pt", "SAM 2.1 tiny": "sam2.1_t.pt", "SAM 2.1 small": "sam2.1_s.pt", "SAM 2.1 base": "sam2.1_b.pt", "SAM 2.1 large": "sam2.1_l.pt", } self.current_sam_model = None self.sam_model = None def change_sam_model(self, model_name): if model_name != "Pick a SAM Model": self.current_sam_model = model_name self.sam_model = SAM(self.sam_models[self.current_sam_model]) print(f"Changed SAM model to: {model_name}") else: self.current_sam_model = None self.sam_model = None print("SAM model unset") def qimage_to_numpy(self, qimage): width = qimage.width() height = qimage.height() fmt = qimage.format() if fmt == 
QImage.Format_Grayscale16: buffer = qimage.constBits().asarray(height * width * 2) image = np.frombuffer(buffer, dtype=np.uint16).reshape((height, width)) image_8bit = self.normalize_16bit_to_8bit(image) return np.stack((image_8bit,) * 3, axis=-1) elif fmt == QImage.Format_RGB16: buffer = qimage.constBits().asarray(height * width * 2) image = np.frombuffer(buffer, dtype=np.uint16).reshape((height, width)) image_8bit = self.normalize_16bit_to_8bit(image) return np.stack((image_8bit,) * 3, axis=-1) elif fmt == QImage.Format_Grayscale8: buffer = qimage.constBits().asarray(height * width) image = np.frombuffer(buffer, dtype=np.uint8).reshape((height, width)) return np.stack((image,) * 3, axis=-1) elif fmt in [QImage.Format_RGB32, QImage.Format_ARGB32, QImage.Format_ARGB32_Premultiplied]: buffer = qimage.constBits().asarray(height * width * 4) image = np.frombuffer(buffer, dtype=np.uint8).reshape((height, width, 4)) return image[:, :, :3] elif fmt == QImage.Format_RGB888: buffer = qimage.constBits().asarray(height * width * 3) image = np.frombuffer(buffer, dtype=np.uint8).reshape((height, width, 3)) return image elif fmt == QImage.Format_Indexed8: buffer = qimage.constBits().asarray(height * width) image = np.frombuffer(buffer, dtype=np.uint8).reshape((height, width)) color_table = qimage.colorTable() rgb_image = np.zeros((height, width, 3), dtype=np.uint8) for y in range(height): for x in range(width): rgb_image[y, x] = QColor(color_table[image[y, x]]).getRgb()[:3] return rgb_image else: converted_image = qimage.convertToFormat(QImage.Format_RGB32) buffer = converted_image.constBits().asarray(height * width * 4) image = np.frombuffer(buffer, dtype=np.uint8).reshape((height, width, 4)) return image[:, :, :3] def normalize_16bit_to_8bit(self, array): return ((array - array.min()) / (array.max() - array.min()) * 255).astype(np.uint8) def apply_sam_prediction(self, image, bbox): try: image_np = self.qimage_to_numpy(image) results = self.sam_model(image_np, bboxes=[bbox]) 
mask = results[0].masks.data[0].cpu().numpy() if mask is not None: print(f"Mask shape: {mask.shape}, Mask sum: {mask.sum()}") contours = self.mask_to_polygon(mask) print(f"Contours generated: {len(contours)} contour(s)") if not contours: print("No valid contours found") return None prediction = { "segmentation": contours[0], "score": float(results[0].boxes.conf[0]) } return prediction else: print("Failed to generate mask") return None except Exception as e: print(f"Error in applying SAM prediction: {str(e)}") import traceback traceback.print_exc() return None def mask_to_polygon(self, mask): import cv2 contours, _ = cv2.findContours((mask > 0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) polygons = [] for contour in contours: if cv2.contourArea(contour) > 10: polygon = contour.flatten().tolist() if len(polygon) >= 6: polygons.append(polygon) print(f"Generated {len(polygons)} valid polygons") return polygons ================================================ FILE: src/digitalsreeni_image_annotator/slice_registration.py ================================================ from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QComboBox, QMessageBox, QProgressDialog, QRadioButton, QButtonGroup, QSpinBox, QApplication, QGroupBox, QDoubleSpinBox) from PyQt5.QtCore import Qt from pystackreg import StackReg from skimage import io import tifffile from PIL import Image import numpy as np import os class SliceRegistrationTool(QDialog): def __init__(self, parent=None): super().__init__(parent) self.setWindowTitle("Slice Registration") self.setGeometry(100, 100, 600, 400) self.setWindowFlags(self.windowFlags() | Qt.Window) self.setWindowModality(Qt.ApplicationModal) # Add modal behavior # Initialize variables first self.input_path = "" self.output_directory = "" self.initUI() def initUI(self): layout = QVBoxLayout() layout.setSpacing(10) # Add consistent spacing # Input selection input_group = QGroupBox("Input Selection") 
input_layout = QVBoxLayout() self.dir_radio = QRadioButton("Directory of Image Files") self.stack_radio = QRadioButton("TIFF Stack") input_group = QButtonGroup(self) input_group.addButton(self.dir_radio) input_group.addButton(self.stack_radio) input_layout.addWidget(self.dir_radio) input_layout.addWidget(self.stack_radio) self.dir_radio.setChecked(True) # Input/Output file selection with labels self.input_label = QLabel("No input selected") self.output_label = QLabel("No output directory selected") file_select_layout = QVBoxLayout() input_file_layout = QHBoxLayout() self.select_input_btn = QPushButton("Select Input") self.select_input_btn.clicked.connect(self.select_input) input_file_layout.addWidget(self.select_input_btn) input_file_layout.addWidget(self.input_label) output_file_layout = QHBoxLayout() self.select_output_btn = QPushButton("Select Output Directory") self.select_output_btn.clicked.connect(self.select_output) output_file_layout.addWidget(self.select_output_btn) output_file_layout.addWidget(self.output_label) file_select_layout.addLayout(input_file_layout) file_select_layout.addLayout(output_file_layout) input_layout.addLayout(file_select_layout) layout.addLayout(input_layout) # Transform type transform_group = QGroupBox("Transformation Settings") transform_layout = QVBoxLayout() transform_combo_layout = QHBoxLayout() transform_combo_layout.addWidget(QLabel("Type:")) self.transform_combo = QComboBox() self.transform_combo.addItems([ "Translation (X-Y Translation Only)", "Rigid Body (Translation + Rotation)", "Scaled Rotation (Translation + Rotation + Scaling)", "Affine (Translation + Rotation + Scaling + Shearing)", "Bilinear (Non-linear; Does not preserve straight lines)" ]) transform_combo_layout.addWidget(self.transform_combo) transform_layout.addLayout(transform_combo_layout) transform_group.setLayout(transform_layout) layout.addWidget(transform_group) # Reference type ref_group = QGroupBox("Reference Settings") ref_layout = QVBoxLayout() 
ref_combo_layout = QHBoxLayout() ref_combo_layout.addWidget(QLabel("Reference:")) self.ref_combo = QComboBox() self.ref_combo.addItems([ "Previous Frame", "First Frame", "Mean of All Frames", "Mean of First N Frames", "Mean of First N Frames + Moving Average" ]) ref_combo_layout.addWidget(self.ref_combo) ref_layout.addLayout(ref_combo_layout) # N frames settings n_frames_layout = QHBoxLayout() n_frames_layout.addWidget(QLabel("N Frames:")) self.n_frames_spin = QSpinBox() self.n_frames_spin.setRange(1, 100) self.n_frames_spin.setValue(10) self.n_frames_spin.setEnabled(False) n_frames_layout.addWidget(self.n_frames_spin) ref_layout.addLayout(n_frames_layout) # Moving average settings moving_avg_layout = QHBoxLayout() moving_avg_layout.addWidget(QLabel("Moving Average Window:")) self.moving_avg_spin = QSpinBox() self.moving_avg_spin.setRange(1, 100) self.moving_avg_spin.setValue(10) self.moving_avg_spin.setEnabled(False) moving_avg_layout.addWidget(self.moving_avg_spin) ref_layout.addLayout(moving_avg_layout) ref_group.setLayout(ref_layout) layout.addWidget(ref_group) # Connect reference combo box self.ref_combo.currentTextChanged.connect(self.on_ref_changed) # Add spacing group spacing_group = QGroupBox("Pixel/Voxel Size") spacing_layout = QVBoxLayout() # XY pixel size xy_size_layout = QHBoxLayout() xy_size_layout.addWidget(QLabel("XY Pixel Size:")) self.xy_size_value = QDoubleSpinBox() self.xy_size_value.setRange(0.001, 1000.0) self.xy_size_value.setValue(1.0) self.xy_size_value.setDecimals(3) xy_size_layout.addWidget(self.xy_size_value) # Z spacing z_size_layout = QHBoxLayout() z_size_layout.addWidget(QLabel("Z Spacing:")) self.z_size_value = QDoubleSpinBox() self.z_size_value.setRange(0.001, 1000.0) self.z_size_value.setValue(1.0) self.z_size_value.setDecimals(3) z_size_layout.addWidget(self.z_size_value) # Unit selector unit_layout = QHBoxLayout() unit_layout.addWidget(QLabel("Unit:")) self.size_unit = QComboBox() self.size_unit.addItems(["nm", "µm", "mm"]) 
self.size_unit.setCurrentText("µm") unit_layout.addWidget(self.size_unit) spacing_layout.addLayout(xy_size_layout) spacing_layout.addLayout(z_size_layout) spacing_layout.addLayout(unit_layout) spacing_group.setLayout(spacing_layout) layout.addWidget(spacing_group) # Register button self.register_btn = QPushButton("Register") self.register_btn.clicked.connect(self.register_slices) layout.addWidget(self.register_btn) self.setLayout(layout) def on_ref_changed(self, text): uses_n_frames = text in ["Mean of First N Frames", "Mean of First N Frames + Moving Average"] self.n_frames_spin.setEnabled(uses_n_frames) self.moving_avg_spin.setEnabled(text == "Mean of First N Frames + Moving Average") QApplication.processEvents() # Ensure UI updates def on_transform_changed(self, text): if text == "Bilinear" and self.ref_combo.currentText() == "Previous": QMessageBox.warning(self, "Warning", "Bilinear transformation cannot be used with 'Previous' reference. " "Please select a different reference type.") self.transform_combo.setCurrentText("Rigid Body") def select_input(self): try: if self.dir_radio.isChecked(): path = QFileDialog.getExistingDirectory( self, "Select Directory with Images", "", QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks ) else: path, _ = QFileDialog.getOpenFileName( self, "Select TIFF Stack", "", "TIFF Files (*.tif *.tiff)", options=QFileDialog.Options() ) if path: self.input_path = path self.input_label.setText(f"Selected: {os.path.basename(path)}") self.input_label.setToolTip(path) QApplication.processEvents() except Exception as e: QMessageBox.critical(self, "Error", f"Error selecting input: {str(e)}") def select_output(self): try: directory = QFileDialog.getExistingDirectory( self, "Select Output Directory", "", QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks ) if directory: self.output_directory = directory self.output_label.setText(f"Selected: {os.path.basename(directory)}") self.output_label.setToolTip(directory) 
QApplication.processEvents() except Exception as e: QMessageBox.critical(self, "Error", f"Error selecting output directory: {str(e)}") def register_slices(self): if not self.input_path or not self.output_directory: QMessageBox.warning(self, "Error", "Please select both input and output paths") return try: progress = QProgressDialog(self) progress.setWindowTitle("Registration Progress") progress.setLabelText("Loading images...") progress.setMinimum(0) progress.setMaximum(100) progress.setWindowModality(Qt.WindowModal) progress.setMinimumWidth(400) progress.show() QApplication.processEvents() # Load images using scikit-image's imread if self.stack_radio.isChecked(): progress.setLabelText("Loading TIFF stack...") img0 = io.imread(self.input_path) else: progress.setLabelText("Loading images from directory...") image_files = sorted([f for f in os.listdir(self.input_path) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff'))]) first_img = io.imread(os.path.join(self.input_path, image_files[0])) img0 = np.zeros((len(image_files), *first_img.shape), dtype=first_img.dtype) img0[0] = first_img for i, fname in enumerate(image_files[1:], 1): img0[i] = io.imread(os.path.join(self.input_path, fname)) # Store original properties original_dtype = img0.dtype print(f"Original image properties:") print(f"Dtype: {original_dtype}") print(f"Range: {img0.min()} - {img0.max()}") print(f"Shape: {img0.shape}") progress.setValue(30) progress.setLabelText("Performing registration...") QApplication.processEvents() # Set up StackReg with selected transformation transform_types = { "Translation (X-Y Translation Only)": StackReg.TRANSLATION, "Rigid Body (Translation + Rotation)": StackReg.RIGID_BODY, "Scaled Rotation (Translation + Rotation + Scaling)": StackReg.SCALED_ROTATION, "Affine (Translation + Rotation + Scaling + Shearing)": StackReg.AFFINE, "Bilinear (Non-linear; Does not preserve straight lines)": StackReg.BILINEAR } transform_type = 
transform_types[self.transform_combo.currentText()] sr = StackReg(transform_type) # Register images selected_ref = self.ref_combo.currentText() progress.setLabelText(f"Registering images using {selected_ref}...") progress.setValue(40) QApplication.processEvents() # Register and transform if selected_ref == "Previous Frame": out_registered = sr.register_transform_stack(img0, reference='previous') elif selected_ref == "First Frame": out_registered = sr.register_transform_stack(img0, reference='first') elif selected_ref == "Mean of All Frames": out_registered = sr.register_transform_stack(img0, reference='mean') elif selected_ref == "Mean of First N Frames": n_frames = self.n_frames_spin.value() out_registered = sr.register_transform_stack(img0, reference='first', n_frames=n_frames) elif selected_ref == "Mean of First N Frames + Moving Average": n_frames = self.n_frames_spin.value() moving_avg = self.moving_avg_spin.value() out_registered = sr.register_transform_stack(img0, reference='first', n_frames=n_frames, moving_average=moving_avg) progress.setValue(80) progress.setLabelText("Saving registered images...") QApplication.processEvents() # Convert back to original dtype without changing values out_registered = out_registered.astype(original_dtype) print(f"Output image properties:") print(f"Dtype: {out_registered.dtype}") print(f"Range: {out_registered.min()} - {out_registered.max()}") print(f"Shape: {out_registered.shape}") # Save output if self.stack_radio.isChecked(): output_name = os.path.splitext(os.path.basename(self.input_path))[0] else: output_name = "registered_stack" output_path = os.path.join(self.output_directory, f"{output_name}_registered.tif") # Get pixel sizes in micrometers (convert if necessary) xy_size = self.xy_size_value.value() z_size = self.z_size_value.value() unit = self.size_unit.currentText() # Convert to micrometers based on selected unit if unit == "nm": xy_size = xy_size / 1000 z_size = z_size / 1000 elif unit == "mm": xy_size = xy_size 
* 1000 z_size = z_size * 1000 # Save the stack tifffile.imwrite( output_path, out_registered, imagej=True, metadata={ 'axes': 'ZYX', 'spacing': z_size, # Z spacing in micrometers 'unit': 'um', 'finterval': xy_size # XY pixel size in micrometers }, resolution=(1.0/xy_size, 1.0/xy_size) # XY Resolution in pixels per unit ) progress.setValue(100) QApplication.processEvents() # Include both XY and Z size info in success message QMessageBox.information(self, "Success", f"Registration completed successfully!\n" f"Output saved to:\n{output_path}\n" f"XY Pixel size: {self.xy_size_value.value()} {unit}\n" f"Z Spacing: {self.z_size_value.value()} {unit}") except Exception as e: print(f"Error occurred: {str(e)}") import traceback traceback.print_exc() QMessageBox.critical(self, "Error", str(e)) def update_progress(self, progress_dialog, current_iteration, end_iteration): """Helper function to update progress during registration""" if end_iteration > 0: percent = int(40 + (current_iteration / end_iteration) * 40) # Scale to 40-80% range progress_dialog.setValue(percent) progress_dialog.setLabelText(f"Processing image {current_iteration}/{end_iteration}...") QApplication.processEvents() def load_images(self): print("Starting image loading...") try: if self.stack_radio.isChecked(): print(f"Loading TIFF stack from: {self.input_path}") # Explicitly use scikit-image's imread for TIFF stacks stack = io.imread(self.input_path) if stack.dtype != np.float32: stack = stack.astype(np.float32) print(f"Loaded TIFF stack shape: {stack.shape}") return stack else: # Load individual images valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff') images = [] files = sorted([f for f in os.listdir(self.input_path) if f.lower().endswith(valid_extensions)]) print(f"Found {len(files)} image files") if not files: raise ValueError("No valid image files found in directory") # Check first image size first_path = os.path.join(self.input_path, files[0]) print(f"Loading first image: 
{first_path}") first_img = np.array(Image.open(first_path)) ref_shape = first_img.shape images.append(first_img) print(f"First image shape: {ref_shape}") # Load remaining images and check sizes for f in files[1:]: img_path = os.path.join(self.input_path, f) print(f"Loading: {f}") img = np.array(Image.open(img_path)) if img.shape != ref_shape: raise ValueError(f"Image {f} has different dimensions from the first image") images.append(img) stack = np.stack(images) print(f"Final stack shape: {stack.shape}") return stack except Exception as e: print(f"Error in load_images: {str(e)}") raise def show_centered(self, parent): parent_geo = parent.geometry() self.move(parent_geo.center() - self.rect().center()) self.show() QApplication.processEvents() # Ensure window displays properly ================================================ FILE: src/digitalsreeni_image_annotator/snake_game.py ================================================ import sys import random from PyQt5.QtWidgets import QApplication, QWidget, QDesktopWidget, QMessageBox from PyQt5.QtGui import QPainter, QColor from PyQt5.QtCore import Qt, QTimer class SnakeGame(QWidget): def __init__(self): super().__init__() self.initUI() def initUI(self): self.setWindowTitle('Secret Snake Game') self.setFixedSize(600, 400) # Increased size self.center() self.snake = [(300, 200), (290, 200), (280, 200)] self.direction = 'RIGHT' self.food = self.place_food() self.score = 0 self.timer = QTimer(self) self.timer.timeout.connect(self.update_game) self.timer.start(100) self.setFocusPolicy(Qt.StrongFocus) self.show() def center(self): qr = self.frameGeometry() cp = QDesktopWidget().availableGeometry().center() qr.moveCenter(cp) self.move(qr.topLeft()) def paintEvent(self, event): painter = QPainter(self) painter.setRenderHint(QPainter.Antialiasing) # Draw snake painter.setBrush(QColor(0, 255, 0)) for segment in self.snake: painter.drawRect(segment[0], segment[1], 10, 10) # Draw food painter.setBrush(QColor(255, 0, 0)) 
painter.drawRect(self.food[0], self.food[1], 10, 10) # Draw score painter.setPen(QColor(0, 0, 0)) painter.drawText(10, 20, f"Score: {self.score}") def keyPressEvent(self, event): key = event.key() if key == Qt.Key_Left and self.direction != 'RIGHT': self.direction = 'LEFT' elif key == Qt.Key_Right and self.direction != 'LEFT': self.direction = 'RIGHT' elif key == Qt.Key_Up and self.direction != 'DOWN': self.direction = 'UP' elif key == Qt.Key_Down and self.direction != 'UP': self.direction = 'DOWN' elif key == Qt.Key_Escape: self.close() def update_game(self): head = self.snake[0] if self.direction == 'LEFT': new_head = (head[0] - 10, head[1]) elif self.direction == 'RIGHT': new_head = (head[0] + 10, head[1]) elif self.direction == 'UP': new_head = (head[0], head[1] - 10) else: # DOWN new_head = (head[0], head[1] + 10) # Check if snake hit the edge if (new_head[0] < 0 or new_head[0] >= 600 or new_head[1] < 0 or new_head[1] >= 400): self.game_over() return self.snake.insert(0, new_head) if new_head == self.food: self.score += 1 self.food = self.place_food() else: self.snake.pop() if new_head in self.snake[1:]: self.game_over() return self.update() def place_food(self): while True: x = random.randint(0, 59) * 10 y = random.randint(0, 39) * 10 if (x, y) not in self.snake: return (x, y) def game_over(self): self.timer.stop() QMessageBox.information(self, "Game Over", f"Your score: {self.score}") self.close() if __name__ == '__main__': app = QApplication(sys.argv) ex = SnakeGame() sys.exit(app.exec_()) ================================================ FILE: src/digitalsreeni_image_annotator/soft_dark_stylesheet.py ================================================ # soft_dark_stylesheet.py soft_dark_stylesheet = """ QWidget { background-color: #2F2F2F; color: #E0E0E0; font-family: Arial, sans-serif; } QMainWindow { background-color: #2A2A2A; } QPushButton { background-color: #4A4A4A; border: 1px solid #5E5E5E; padding: 5px 10px; border-radius: 3px; color: #E0E0E0; } 
QPushButton:hover { background-color: #545454; } QPushButton:pressed { background-color: #404040; } QPushButton:checked { background-color: #606060; border: 2px solid #808080; color: #FFFFFF; } QListWidget, QTreeWidget { background-color: #3A3A3A; border: 1px solid #4A4A4A; border-radius: 3px; color: #E0E0E0; } QListWidget::item, QTreeWidget::item { color: #E0E0E0; } QListWidget::item:selected, QTreeWidget::item:selected { background-color: #4A4A4A; color: #FFFFFF; /* Make selected items a bit brighter */ } QLabel { color: #E0E0E0; } QLabel.section-header { font-weight: bold; font-size: 14px; padding: 5px 0; color: #FFFFFF; /* Bright white color for better visibility in dark mode */ } QLineEdit, QTextEdit, QPlainTextEdit { background-color: #3A3A3A; border: 1px solid #4A4A4A; color: #E0E0E0; padding: 2px; border-radius: 3px; } QSlider::groove:horizontal { background: #4A4A4A; height: 8px; border-radius: 4px; } QSlider::handle:horizontal { background: #6A6A6A; width: 18px; margin-top: -5px; margin-bottom: -5px; border-radius: 9px; } QSlider::handle:horizontal:hover { background: #7A7A7A; } QScrollBar:vertical, QScrollBar:horizontal { background-color: #3A3A3A; width: 12px; height: 12px; } QScrollBar::handle:vertical, QScrollBar::handle:horizontal { background-color: #5A5A5A; border-radius: 6px; min-height: 20px; } QScrollBar::handle:vertical:hover, QScrollBar::handle:horizontal:hover { background-color: #6A6A6A; } QScrollBar::add-line, QScrollBar::sub-line { background: none; } QMenuBar { background-color: #2F2F2F; } QMenuBar::item { padding: 5px 10px; background-color: transparent; } QMenuBar::item:selected { background-color: #3A3A3A; } QMenu { background-color: #2F2F2F; border: 1px solid #3A3A3A; } QMenu::item { padding: 5px 20px 5px 20px; } QMenu::item:selected { background-color: #3A3A3A; } QToolTip { background-color: #2F2F2F; color: #E0E0E0; border: 1px solid #3A3A3A; } QStatusBar { background-color: #2A2A2A; color: #B0B0B0; } QListWidget::item { color: none; 
} """ ================================================ FILE: src/digitalsreeni_image_annotator/stack_interpolator.py ================================================ import os import numpy as np from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QComboBox, QMessageBox, QProgressDialog, QRadioButton, QButtonGroup, QGroupBox, QDoubleSpinBox, QApplication) from PyQt5.QtCore import Qt from scipy.interpolate import RegularGridInterpolator from skimage import io import tifffile class StackInterpolator(QDialog): def __init__(self, parent=None): super().__init__(parent) self.setWindowTitle("Stack Interpolator") self.setGeometry(100, 100, 600, 400) self.setWindowFlags(self.windowFlags() | Qt.Window) self.setWindowModality(Qt.ApplicationModal) # Added window modality # Initialize variables self.input_path = "" self.output_directory = "" self.initUI() def initUI(self): layout = QVBoxLayout() layout.setSpacing(10) # Add consistent spacing # Input selection input_group = QGroupBox("Input Selection") input_layout = QVBoxLayout() # Radio buttons for input type self.dir_radio = QRadioButton("Directory of Image Files") self.stack_radio = QRadioButton("TIFF Stack") input_group_buttons = QButtonGroup(self) input_group_buttons.addButton(self.dir_radio) input_group_buttons.addButton(self.stack_radio) input_layout.addWidget(self.dir_radio) input_layout.addWidget(self.stack_radio) self.dir_radio.setChecked(True) input_group.setLayout(input_layout) layout.addWidget(input_group) # Interpolation method method_group = QGroupBox("Interpolation Settings") method_layout = QVBoxLayout() method_combo_layout = QHBoxLayout() method_combo_layout.addWidget(QLabel("Method:")) self.method_combo = QComboBox() self.method_combo.addItems([ "linear", "nearest", "slinear", "cubic", "quintic", "pchip" ]) method_combo_layout.addWidget(self.method_combo) method_layout.addLayout(method_combo_layout) method_group.setLayout(method_layout) 
layout.addWidget(method_group) # Original dimensions group orig_group = QGroupBox("Original Dimensions") orig_layout = QVBoxLayout() orig_xy_layout = QHBoxLayout() orig_xy_layout.addWidget(QLabel("XY Pixel Size:")) self.orig_xy_size = QDoubleSpinBox() self.orig_xy_size.setRange(0.001, 1000.0) self.orig_xy_size.setValue(1.0) self.orig_xy_size.setDecimals(3) orig_xy_layout.addWidget(self.orig_xy_size) orig_z_layout = QHBoxLayout() orig_z_layout.addWidget(QLabel("Z Spacing:")) self.orig_z_size = QDoubleSpinBox() self.orig_z_size.setRange(0.001, 1000.0) self.orig_z_size.setValue(1.0) self.orig_z_size.setDecimals(3) orig_z_layout.addWidget(self.orig_z_size) orig_layout.addLayout(orig_xy_layout) orig_layout.addLayout(orig_z_layout) orig_group.setLayout(orig_layout) layout.addWidget(orig_group) # New dimensions group new_group = QGroupBox("New Dimensions") new_layout = QVBoxLayout() new_xy_layout = QHBoxLayout() new_xy_layout.addWidget(QLabel("XY Pixel Size:")) self.new_xy_size = QDoubleSpinBox() self.new_xy_size.setRange(0.001, 1000.0) self.new_xy_size.setValue(1.0) self.new_xy_size.setDecimals(3) new_xy_layout.addWidget(self.new_xy_size) new_z_layout = QHBoxLayout() new_z_layout.addWidget(QLabel("Z Spacing:")) self.new_z_size = QDoubleSpinBox() self.new_z_size.setRange(0.001, 1000.0) self.new_z_size.setValue(1.0) self.new_z_size.setDecimals(3) new_z_layout.addWidget(self.new_z_size) new_layout.addLayout(new_xy_layout) new_layout.addLayout(new_z_layout) new_group.setLayout(new_layout) layout.addWidget(new_group) # Units selector unit_group = QGroupBox("Unit Settings") unit_layout = QHBoxLayout() unit_layout.addWidget(QLabel("Unit:")) self.size_unit = QComboBox() self.size_unit.addItems(["nm", "µm", "mm"]) self.size_unit.setCurrentText("µm") unit_layout.addWidget(self.size_unit) unit_group.setLayout(unit_layout) layout.addWidget(unit_group) # Input/Output buttons button_group = QGroupBox("File Selection") button_layout = QVBoxLayout() # Input selection input_file_layout = 
QHBoxLayout() self.input_label = QLabel("No input selected") self.select_input_btn = QPushButton("Select Input") self.select_input_btn.clicked.connect(self.select_input) input_file_layout.addWidget(self.select_input_btn) input_file_layout.addWidget(self.input_label) button_layout.addLayout(input_file_layout) # Output selection output_file_layout = QHBoxLayout() self.output_label = QLabel("No output directory selected") self.select_output_btn = QPushButton("Select Output Directory") self.select_output_btn.clicked.connect(self.select_output) output_file_layout.addWidget(self.select_output_btn) output_file_layout.addWidget(self.output_label) button_layout.addLayout(output_file_layout) button_group.setLayout(button_layout) layout.addWidget(button_group) # Interpolate button self.interpolate_btn = QPushButton("Interpolate") self.interpolate_btn.clicked.connect(self.interpolate_stack) layout.addWidget(self.interpolate_btn) self.setLayout(layout) def select_input(self): try: if self.dir_radio.isChecked(): path = QFileDialog.getExistingDirectory( self, "Select Directory with Images", "", QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks ) else: path, _ = QFileDialog.getOpenFileName( self, "Select TIFF Stack", "", "TIFF Files (*.tif *.tiff)", options=QFileDialog.Options() ) if path: self.input_path = path self.input_label.setText(f"Selected: {os.path.basename(path)}") self.input_label.setToolTip(path) QApplication.processEvents() except Exception as e: QMessageBox.critical(self, "Error", f"Error selecting input: {str(e)}") def select_output(self): try: directory = QFileDialog.getExistingDirectory( self, "Select Output Directory", "", QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks ) if directory: self.output_directory = directory self.output_label.setText(f"Selected: {os.path.basename(directory)}") self.output_label.setToolTip(directory) QApplication.processEvents() except Exception as e: QMessageBox.critical(self, "Error", f"Error selecting output 
directory: {str(e)}") def load_images(self): try: progress = QProgressDialog("Loading images...", "Cancel", 0, 100, self) progress.setWindowModality(Qt.WindowModal) progress.show() QApplication.processEvents() if self.stack_radio.isChecked(): progress.setLabelText("Loading TIFF stack...") progress.setValue(20) QApplication.processEvents() # Load stack preserving original dtype stack = io.imread(self.input_path) print(f"Loaded stack dtype: {stack.dtype}") print(f"Value range: [{stack.min()}, {stack.max()}]") progress.setValue(90) QApplication.processEvents() return stack else: valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff') files = sorted([f for f in os.listdir(self.input_path) if f.lower().endswith(valid_extensions)]) if not files: raise ValueError("No valid image files found in directory") progress.setMaximum(len(files)) # Load first image to get dimensions and dtype first_img = io.imread(os.path.join(self.input_path, files[0])) stack = np.zeros((len(files), *first_img.shape), dtype=first_img.dtype) stack[0] = first_img print(f"Created stack with dtype: {stack.dtype}") print(f"First image range: [{first_img.min()}, {first_img.max()}]") # Load remaining images for i, fname in enumerate(files[1:], 1): progress.setValue(i) progress.setLabelText(f"Loading image {i+1}/{len(files)}") QApplication.processEvents() if progress.wasCanceled(): raise InterruptedError("Loading cancelled by user") img = io.imread(os.path.join(self.input_path, fname)) if img.shape != first_img.shape: raise ValueError(f"Image {fname} has different dimensions from the first image") if img.dtype != first_img.dtype: raise ValueError(f"Image {fname} has different bit depth from the first image") stack[i] = img return stack except Exception as e: raise ValueError(f"Error loading images: {str(e)}") finally: progress.close() QApplication.processEvents() def interpolate_stack(self): if not self.input_path or not self.output_directory: QMessageBox.warning(self, "Missing Paths", 
"Please select both input and output paths") return try: # Create progress dialog progress = QProgressDialog("Processing...", "Cancel", 0, 100, self) progress.setWindowModality(Qt.WindowModal) progress.setWindowTitle("Interpolation Progress") progress.setMinimumDuration(0) progress.setMinimumWidth(400) progress.show() QApplication.processEvents() # Load images progress.setLabelText("Loading images...") progress.setValue(10) QApplication.processEvents() input_stack = self.load_images() original_dtype = input_stack.dtype type_range = np.iinfo(original_dtype) if np.issubdtype(original_dtype, np.integer) else None print(f"Original data type: {original_dtype}") print(f"Original shape: {input_stack.shape}") print(f"Original range: {input_stack.min()} - {input_stack.max()}") # Normalize input data to float64 for interpolation input_stack_normalized = input_stack.astype(np.float64) if type_range is not None: input_stack_normalized = input_stack_normalized / type_range.max progress.setLabelText("Calculating dimensions...") progress.setValue(20) QApplication.processEvents() # Calculate dimensions and coordinates z_old = np.arange(input_stack.shape[0]) * self.orig_z_size.value() y_old = np.arange(input_stack.shape[1]) * self.orig_xy_size.value() x_old = np.arange(input_stack.shape[2]) * self.orig_xy_size.value() z_new = np.arange(z_old[0], z_old[-1] + self.new_z_size.value(), self.new_z_size.value()) y_new = np.arange(0, input_stack.shape[1] * self.orig_xy_size.value(), self.new_xy_size.value()) x_new = np.arange(0, input_stack.shape[2] * self.orig_xy_size.value(), self.new_xy_size.value()) y_new = y_new[y_new < y_old[-1] + self.new_xy_size.value()] x_new = x_new[x_new < x_old[-1] + self.new_xy_size.value()] new_shape = (len(z_new), len(y_new), len(x_new)) print(f"New dimensions will be: {new_shape}") # Initialize output array interpolated_data = np.zeros(new_shape, dtype=np.float64) method = self.method_combo.currentText() # For higher-order methods, use a hybrid approach if 
method in ['cubic', 'quintic', 'pchip']: progress.setLabelText("Using hybrid interpolation approach...") progress.setValue(30) QApplication.processEvents() from scipy.interpolate import interp1d # Process each XY point total_points = input_stack.shape[1] * input_stack.shape[2] points_processed = 0 temp_stack = np.zeros((len(z_new), input_stack.shape[1], input_stack.shape[2]), dtype=np.float64) for y in range(input_stack.shape[1]): for x in range(input_stack.shape[2]): if progress.wasCanceled(): return points_processed += 1 if points_processed % 1000 == 0: progress_val = 30 + (points_processed / total_points * 30) progress.setValue(int(progress_val)) progress.setLabelText(f"Interpolating Z dimension: {points_processed}/{total_points} points") QApplication.processEvents() z_profile = input_stack_normalized[:, y, x] f = interp1d(z_old, z_profile, kind=method, bounds_error=False, fill_value='extrapolate') temp_stack[:, y, x] = f(z_new) progress.setLabelText("Interpolating XY planes...") progress.setValue(60) QApplication.processEvents() for z in range(len(z_new)): if progress.wasCanceled(): return progress.setValue(60 + int((z / len(z_new)) * 30)) progress.setLabelText(f"Processing XY plane {z+1}/{len(z_new)}") QApplication.processEvents() interpolator = RegularGridInterpolator( (y_old, x_old), temp_stack[z], method='linear', bounds_error=False, fill_value=0 ) yy, xx = np.meshgrid(y_new, x_new, indexing='ij') pts = np.stack([yy.ravel(), xx.ravel()], axis=-1) interpolated_data[z] = interpolator(pts).reshape(len(y_new), len(x_new)) del temp_stack else: # For linear and nearest neighbor progress.setLabelText("Creating interpolator...") progress.setValue(30) QApplication.processEvents() interpolator = RegularGridInterpolator( (z_old, y_old, x_old), input_stack_normalized, method=method, bounds_error=False, fill_value=0 ) slices_per_batch = max(1, len(z_new) // 20) total_batches = (len(z_new) + slices_per_batch - 1) // slices_per_batch for batch_idx in range(total_batches): 
if progress.wasCanceled(): return start_idx = batch_idx * slices_per_batch end_idx = min((batch_idx + 1) * slices_per_batch, len(z_new)) progress.setLabelText(f"Interpolating batch {batch_idx + 1}/{total_batches}") progress_value = int(40 + (batch_idx/total_batches)*40) progress.setValue(progress_value) QApplication.processEvents() zz, yy, xx = np.meshgrid( z_new[start_idx:end_idx], y_new, x_new, indexing='ij' ) pts = np.stack([zz.ravel(), yy.ravel(), xx.ravel()], axis=-1) interpolated_data[start_idx:end_idx] = interpolator(pts).reshape( end_idx - start_idx, len(y_new), len(x_new) ) # Convert back to original dtype progress.setLabelText("Converting to original bit depth...") progress.setValue(90) QApplication.processEvents() if np.issubdtype(original_dtype, np.integer): # Scale back to original range interpolated_data = np.clip(interpolated_data, 0, 1) interpolated_data = (interpolated_data * type_range.max).astype(original_dtype) else: interpolated_data = interpolated_data.astype(original_dtype) print(f"Final dtype: {interpolated_data.dtype}") print(f"Final range: [{interpolated_data.min()}, {interpolated_data.max()}]") # Save output progress.setLabelText("Saving interpolated stack...") progress.setValue(95) QApplication.processEvents() if self.stack_radio.isChecked(): output_name = os.path.splitext(os.path.basename(self.input_path))[0] else: output_name = "interpolated_stack" output_path = os.path.join(self.output_directory, f"{output_name}_interpolated.tif") # Convert sizes to micrometers for metadata unit = self.size_unit.currentText() xy_size = self.new_xy_size.value() z_size = self.new_z_size.value() if unit == "nm": xy_size /= 1000 z_size /= 1000 elif unit == "mm": xy_size *= 1000 z_size *= 1000 # Save with metadata tifffile.imwrite( output_path, interpolated_data, imagej=True, metadata={ 'axes': 'ZYX', 'spacing': z_size, 'unit': 'um', 'finterval': xy_size }, resolution=(1.0/xy_size, 1.0/xy_size) ) progress.setValue(100) QApplication.processEvents() 
QMessageBox.information( self, "Success", f"Interpolation completed successfully!\n" f"Output saved to:\n{output_path}\n" f"New dimensions: {interpolated_data.shape}\n" f"Bit depth: {interpolated_data.dtype}\n" f"XY Pixel size: {self.new_xy_size.value()} {unit}\n" f"Z Spacing: {self.new_z_size.value()} {unit}" ) except Exception as e: QMessageBox.critical(self, "Error", str(e)) print(f"Error occurred: {str(e)}") import traceback traceback.print_exc() finally: progress.close() QApplication.processEvents() def show_centered(self, parent): parent_geo = parent.geometry() self.move(parent_geo.center() - self.rect().center()) self.show() QApplication.processEvents() # Ensure UI updates # Helper function to create the dialog def show_stack_interpolator(parent): dialog = StackInterpolator(parent) dialog.show_centered(parent) return dialog ================================================ FILE: src/digitalsreeni_image_annotator/stack_to_slices.py ================================================ import os import numpy as np from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QMessageBox, QComboBox, QGridLayout, QWidget, QProgressDialog, QApplication) from PyQt5.QtCore import Qt from tifffile import TiffFile from czifile import CziFile from PIL import Image class DimensionDialog(QDialog): def __init__(self, shape, file_name, parent=None): super().__init__(parent) self.setWindowTitle("Assign Dimensions") self.shape = shape self.initUI(file_name) def initUI(self, file_name): layout = QVBoxLayout() file_name_label = QLabel(f"File: {file_name}") file_name_label.setWordWrap(True) layout.addWidget(file_name_label) dim_widget = QWidget() dim_layout = QGridLayout(dim_widget) self.combos = [] dimensions = ['T', 'Z', 'C', 'S', 'H', 'W'] for i, dim in enumerate(self.shape): dim_layout.addWidget(QLabel(f"Dimension {i} (size {dim}):"), i, 0) combo = QComboBox() combo.addItems(dimensions) dim_layout.addWidget(combo, i, 1) 
self.combos.append(combo) layout.addWidget(dim_widget) self.button = QPushButton("OK") self.button.clicked.connect(self.accept) layout.addWidget(self.button) self.setLayout(layout) def get_dimensions(self): return [combo.currentText() for combo in self.combos] class StackToSlicesDialog(QDialog): def __init__(self, parent=None): super().__init__(parent) self.setWindowTitle("Stack to Slices") self.setGeometry(100, 100, 400, 200) self.setWindowFlags(self.windowFlags() | Qt.Window) self.setWindowModality(Qt.ApplicationModal) self.dimensions = None self.initUI() def initUI(self): layout = QVBoxLayout() self.file_label = QLabel("No file selected") layout.addWidget(self.file_label) select_button = QPushButton("Select Stack File") select_button.clicked.connect(self.select_file) layout.addWidget(select_button) self.convert_button = QPushButton("Convert to Slices") self.convert_button.clicked.connect(self.convert_to_slices) self.convert_button.setEnabled(False) layout.addWidget(self.convert_button) self.setLayout(layout) def select_file(self): self.file_name, _ = QFileDialog.getOpenFileName(self, "Select Stack File", "", "Image Files (*.tif *.tiff *.czi)") if self.file_name: self.file_label.setText(f"Selected file: {os.path.basename(self.file_name)}") QApplication.processEvents() self.process_file() def process_file(self): if self.file_name.lower().endswith(('.tif', '.tiff')): self.process_tiff() elif self.file_name.lower().endswith('.czi'): self.process_czi() def process_tiff(self): with TiffFile(self.file_name) as tif: image_array = tif.asarray() self.get_dimensions(image_array.shape) def process_czi(self): with CziFile(self.file_name) as czi: image_array = czi.asarray() self.get_dimensions(image_array.shape) def get_dimensions(self, shape): dialog = DimensionDialog(shape, os.path.basename(self.file_name), self) dialog.setWindowModality(Qt.ApplicationModal) if dialog.exec_(): self.dimensions = dialog.get_dimensions() self.convert_button.setEnabled(True) else: 
self.dimensions = None self.convert_button.setEnabled(False) QApplication.processEvents() def convert_to_slices(self): if not hasattr(self, 'file_name') or not self.dimensions: QMessageBox.warning(self, "Invalid Input", "Please select a file and assign dimensions first.") return output_dir = QFileDialog.getExistingDirectory(self, "Select Output Directory") if not output_dir: return if self.file_name.lower().endswith(('.tif', '.tiff')): with TiffFile(self.file_name) as tif: image_array = tif.asarray() elif self.file_name.lower().endswith('.czi'): with CziFile(self.file_name) as czi: image_array = czi.asarray() self.save_slices(image_array, output_dir) def save_slices(self, image_array, output_dir): base_name = os.path.splitext(os.path.basename(self.file_name))[0] slice_indices = [i for i, dim in enumerate(self.dimensions) if dim not in ['H', 'W']] total_slices = np.prod([image_array.shape[i] for i in slice_indices]) progress = QProgressDialog("Saving slices...", "Cancel", 0, total_slices, self) progress.setWindowModality(Qt.WindowModal) progress.setWindowTitle("Progress") progress.setMinimumDuration(0) progress.setValue(0) progress.show() try: for idx, _ in enumerate(np.ndindex(tuple(image_array.shape[i] for i in slice_indices))): if progress.wasCanceled(): break full_idx = [slice(None)] * len(self.dimensions) for i, val in zip(slice_indices, _): full_idx[i] = val slice_array = image_array[tuple(full_idx)] if slice_array.ndim > 2: slice_array = slice_array.squeeze() if slice_array.dtype == np.uint16: mode = 'I;16' elif slice_array.dtype == np.uint8: mode = 'L' else: slice_array = ((slice_array - slice_array.min()) / (slice_array.max() - slice_array.min()) * 65535).astype(np.uint16) mode = 'I;16' slice_name = f"{base_name}_{'_'.join([f'{self.dimensions[i]}{val+1}' for i, val in zip(slice_indices, _)])}.png" img = Image.fromarray(slice_array, mode=mode) img.save(os.path.join(output_dir, slice_name)) progress.setValue(idx + 1) QApplication.processEvents() if 
progress.wasCanceled(): QMessageBox.warning(self, "Conversion Interrupted", "The conversion process was interrupted.") else: QMessageBox.information(self, "Conversion Complete", f"All slices have been saved to {output_dir}") finally: progress.close() def show_centered(self, parent): parent_geo = parent.geometry() self.move(parent_geo.center() - self.rect().center()) self.show() def show_stack_to_slices(parent): dialog = StackToSlicesDialog(parent) dialog.show_centered(parent) return dialog ================================================ FILE: src/digitalsreeni_image_annotator/utils.py ================================================ """ Utility functions for the Image Annotator application. This module contains helper functions used across the application. @DigitalSreeni Dr. Sreenivas Bhattiprolu """ import numpy as np def calculate_area(annotation): if "segmentation" in annotation and annotation["segmentation"] is not None: # Polygon area x, y = annotation["segmentation"][0::2], annotation["segmentation"][1::2] return 0.5 * abs(sum(x[i] * y[i+1] - x[i+1] * y[i] for i in range(-1, len(x)-1))) elif "bbox" in annotation: # Rectangle area x, y, w, h = annotation["bbox"] return w * h return 0 def calculate_bbox(segmentation): x_coordinates, y_coordinates = segmentation[0::2], segmentation[1::2] x_min, y_min = min(x_coordinates), min(y_coordinates) x_max, y_max = max(x_coordinates), max(y_coordinates) width, height = x_max - x_min, y_max - y_min return [x_min, y_min, width, height] def normalize_image(image_array): """Normalize image array to 8-bit range.""" if image_array.dtype != np.uint8: image_array = ((image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255).astype(np.uint8) return image_array ================================================ FILE: src/digitalsreeni_image_annotator/yolo_trainer.py ================================================ import os from ultralytics import YOLO from PyQt5.QtWidgets import QFileDialog, QMessageBox from 
PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QLineEdit, QLabel, QFileDialog, QDialogButtonBox) import yaml import numpy as np from pathlib import Path from .export_formats import export_yolo_v5plus from collections import deque from PyQt5.QtWidgets import QDialog, QVBoxLayout, QTextEdit, QPushButton from PyQt5.QtCore import Qt, pyqtSignal, QObject class TrainingInfoDialog(QDialog): stop_signal = pyqtSignal() def __init__(self, parent=None): super().__init__(parent) self.setWindowTitle("Training Progress") self.setModal(False) self.layout = QVBoxLayout(self) self.info_text = QTextEdit(self) self.info_text.setReadOnly(True) self.layout.addWidget(self.info_text) self.stop_button = QPushButton("Stop Training", self) self.stop_button.clicked.connect(self.stop_training) self.layout.addWidget(self.stop_button) self.close_button = QPushButton("Close", self) self.close_button.clicked.connect(self.hide) self.layout.addWidget(self.close_button) self.setMinimumSize(400, 300) def update_info(self, text): self.info_text.append(text) self.info_text.verticalScrollBar().setValue(self.info_text.verticalScrollBar().maximum()) def stop_training(self): self.stop_signal.emit() self.stop_button.setEnabled(False) self.stop_button.setText("Stopping...") def closeEvent(self, event): event.ignore() self.hide() class LoadPredictionModelDialog(QDialog): def __init__(self, parent=None): super().__init__(parent) self.setWindowTitle("Load Prediction Model and YAML") self.model_path = "" self.yaml_path = "" layout = QVBoxLayout(self) # Model file selection model_layout = QHBoxLayout() self.model_edit = QLineEdit() model_button = QPushButton("Browse") model_button.clicked.connect(self.browse_model) model_layout.addWidget(QLabel("Model File:")) model_layout.addWidget(self.model_edit) model_layout.addWidget(model_button) layout.addLayout(model_layout) # YAML file selection yaml_layout = QHBoxLayout() self.yaml_edit = QLineEdit() yaml_button = QPushButton("Browse") 
yaml_button.clicked.connect(self.browse_yaml) yaml_layout.addWidget(QLabel("YAML File:")) yaml_layout.addWidget(self.yaml_edit) yaml_layout.addWidget(yaml_button) layout.addLayout(yaml_layout) # OK and Cancel buttons self.button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) self.button_box.accepted.connect(self.accept) self.button_box.rejected.connect(self.reject) layout.addWidget(self.button_box) def browse_model(self): file_name, _ = QFileDialog.getOpenFileName(self, "Select YOLO Model", "", "YOLO Model (*.pt)") if file_name: self.model_path = file_name self.model_edit.setText(file_name) def browse_yaml(self): file_name, _ = QFileDialog.getOpenFileName(self, "Select YAML File", "", "YAML Files (*.yaml *.yml)") if file_name: self.yaml_path = file_name self.yaml_edit.setText(file_name) class YOLOTrainer(QObject): progress_signal = pyqtSignal(str) def __init__(self, project_dir, main_window): super().__init__() self.project_dir = project_dir self.main_window = main_window self.model = None self.dataset_path = os.path.join(project_dir, "yolo_dataset") self.model_path = os.path.join(project_dir, "yolo_model") self.yaml_path = None self.yaml_data = None self.epoch_info = deque(maxlen=10) self.progress_callback = None self.total_epochs = None self.conf_threshold = 0.25 self.stop_training = False self.class_names = None def load_model(self, model_path=None): if model_path is None: model_path, _ = QFileDialog.getOpenFileName(self.main_window, "Select YOLO Model", "", "YOLO Model (*.pt)") if model_path: try: self.model = YOLO(model_path) return True except Exception as e: QMessageBox.critical(self.main_window, "Error Loading Model", f"Could not load the model. 
Error: {str(e)}") return False def prepare_dataset(self): output_dir, yaml_path = export_yolo_v5plus( self.main_window.all_annotations, self.main_window.class_mapping, self.main_window.image_paths, self.main_window.slices, self.main_window.image_slices, self.dataset_path ) yaml_path = Path(yaml_path) with yaml_path.open('r') as f: yaml_content = yaml.safe_load(f) # Update paths for new YOLO v5+ structure yaml_content['train'] = 'images/train' # Changed from train/images yaml_content['val'] = 'images/val' # Changed from train/images yaml_content['test'] = '../test/images' with yaml_path.open('w') as f: yaml.dump(yaml_content, f, default_flow_style=False) self.yaml_path = str(yaml_path) return self.yaml_path def load_yaml(self, yaml_path=None): if yaml_path is None: yaml_path, _ = QFileDialog.getOpenFileName(self.main_window, "Select YOLO Dataset YAML", "", "YAML Files (*.yaml *.yml)") if yaml_path and os.path.exists(yaml_path): with open(yaml_path, 'r') as f: try: yaml_data = yaml.safe_load(f) print(f"Loaded YAML contents: {yaml_data}") # Ensure paths are relative for key in ['train', 'val', 'test']: if key in yaml_data and os.path.isabs(yaml_data[key]): yaml_data[key] = os.path.relpath(yaml_data[key], start=os.path.dirname(yaml_path)) print(f"Updated YAML contents: {yaml_data}") # Save the updated YAML data self.yaml_data = yaml_data self.yaml_path = yaml_path # Write the updated YAML back to the file with open(yaml_path, 'w') as f: yaml.dump(yaml_data, f, default_flow_style=False) return True except yaml.YAMLError as e: QMessageBox.critical(self.main_window, "Error Loading YAML", f"Invalid YAML file. 
Error: {str(e)}") return False def on_train_epoch_end(self, trainer): epoch = trainer.epoch + 1 # Add 1 to start from 1 instead of 0 total_epochs = trainer.epochs loss = trainer.loss.item() progress_text = f"Epoch {epoch}/{total_epochs}, Loss: {loss:.4f}" # Only emit the signal, don't call the callback directly self.progress_signal.emit(progress_text) if self.stop_training: trainer.model.stop = True self.stop_training = False return False return True def train_model(self, epochs=100, imgsz=640): if self.model is None: raise ValueError("No model loaded. Please load a model first.") if self.yaml_path is None or not Path(self.yaml_path).exists(): raise FileNotFoundError("Dataset YAML not found. Please prepare or load a dataset first.") self.stop_training = False self.total_epochs = epochs self.epoch_info.clear() # Add the callback self.model.add_callback("on_train_epoch_end", self.on_train_epoch_end) try: yaml_path = Path(self.yaml_path) yaml_dir = yaml_path.parent print(f"Training with YAML: {yaml_path}") print(f"YAML directory: {yaml_dir}") with yaml_path.open('r') as f: yaml_content = yaml.safe_load(f) print(f"YAML content: {yaml_content}") # For now, use train as val since we don't have separate validation set train_dir = str(yaml_dir / 'images' / 'train') # Update YAML content with correct paths yaml_content['train'] = train_dir yaml_content['val'] = train_dir # Use same directory for validation # Create the val directory structure if it doesn't exist val_img_dir = yaml_dir / 'images' / 'val' val_label_dir = yaml_dir / 'labels' / 'val' val_img_dir.mkdir(parents=True, exist_ok=True) val_label_dir.mkdir(parents=True, exist_ok=True) # Write updated YAML with adjusted paths temp_yaml_path = yaml_dir / 'temp_train.yaml' with temp_yaml_path.open('w') as f: yaml.dump(yaml_content, f, default_flow_style=False) print(f"Training with updated YAML: {temp_yaml_path}") print(f"Updated YAML content: {yaml_content}") results = self.model.train(data=str(temp_yaml_path), 
epochs=epochs, imgsz=imgsz) return results finally: # Clear the callback self.model.callbacks["on_train_epoch_end"] = [] # Remove temporary YAML file if 'temp_yaml_path' in locals(): temp_yaml_path.unlink(missing_ok=True) def verify_dataset_structure(self): yaml_path = Path(self.yaml_path) yaml_dir = yaml_path.parent with yaml_path.open('r') as f: yaml_content = yaml.safe_load(f) # Use paths from YAML content train_images_dir = yaml_dir / yaml_content.get('train', 'images/train') val_images_dir = yaml_dir / yaml_content.get('val', 'images/val') train_labels_dir = yaml_dir / 'labels' / 'train' # Labels directory corresponds to images val_labels_dir = yaml_dir / 'labels' / 'val' # Labels directory corresponds to images # Check both train and val directories missing_dirs = [] if not train_images_dir.exists(): missing_dirs.append(f"Training images directory: {train_images_dir}") if not train_labels_dir.exists(): missing_dirs.append(f"Training labels directory: {train_labels_dir}") if not val_images_dir.exists(): missing_dirs.append(f"Validation images directory: {val_images_dir}") if not val_labels_dir.exists(): missing_dirs.append(f"Validation labels directory: {val_labels_dir}") if missing_dirs: raise FileNotFoundError(f"The following directories were not found:\n" + "\n".join(missing_dirs)) print(f"Dataset structure verified:") print(f"Train images: {train_images_dir}") print(f"Train labels: {train_labels_dir}") print(f"Val images: {val_images_dir}") print(f"Val labels: {val_labels_dir}") def check_ultralytics_settings(self): settings_path = Path.home() / ".config" / "Ultralytics" / "settings.yaml" if settings_path.exists(): with settings_path.open('r') as f: settings = yaml.safe_load(f) print(f"Ultralytics settings: {settings}") else: print("Ultralytics settings file not found.") def stop_training_signal(self): self.stop_training = True self.progress_signal.emit("Stopping training...") def set_progress_callback(self, callback): self.progress_callback = callback def 
stop_training_callback(self, trainer): if getattr(self, 'stop_training', False): trainer.model.stop = True self.stop_training = False def on_epoch_end(self, trainer): # Get current epoch epoch = trainer.epoch if hasattr(trainer, 'epoch') else trainer.current_epoch # Get total epochs total_epochs = self.total_epochs # Use the value we set in train_model # Get loss if hasattr(trainer, 'metrics') and 'train/box_loss' in trainer.metrics: loss = trainer.metrics['train/box_loss'] elif hasattr(trainer, 'loss'): loss = trainer.loss else: loss = 0 # Default value if loss can't be found # Ensure loss is a number loss = float(loss) info = f"Epoch {epoch}/{total_epochs}, Loss: {loss:.4f}" self.epoch_info.append(info) display_text = f"Current Progress:\n" + "\n".join(self.epoch_info) if self.progress_callback: self.progress_callback(display_text) def save_model(self): if self.model is None: raise ValueError("No model to save. Please train a model first.") save_path, _ = QFileDialog.getSaveFileName(self.main_window, "Save YOLO Model", "", "YOLO Model (*.pt)") if save_path: self.model.export(save_path) return True return False def load_prediction_model(self, model_path, yaml_path): try: self.model = YOLO(model_path) with open(yaml_path, 'r') as f: self.prediction_yaml = yaml.safe_load(f) if 'names' not in self.prediction_yaml: raise ValueError("The YAML file does not contain a 'names' section for class names.") self.class_names = self.prediction_yaml['names'] print(f"Loaded class names: {self.class_names}") # Verify that the number of classes in the YAML matches the model if len(self.class_names) != len(self.model.names): mismatch_message = (f"Warning: Number of classes in YAML ({len(self.class_names)}) " f"does not match the model ({len(self.model.names)}). 
" "This may cause issues during prediction.") print(mismatch_message) return True, mismatch_message return True, None except Exception as e: error_message = f"Error loading model or YAML: {str(e)}" print(error_message) return False, error_message def predict(self, input_data): if self.model is None: raise ValueError("No model loaded. Please load a model first.") if isinstance(input_data, str): # It's a file path results = self.model(input_data, task='segment', conf=self.conf_threshold, save=False, show=False) elif isinstance(input_data, np.ndarray): # It's a numpy array results = self.model(input_data, task='segment', conf=self.conf_threshold, save=False, show=False) else: raise ValueError("Invalid input type. Expected file path or numpy array.") # Get the input size used for prediction and the original image size input_size = results[0].orig_shape original_size = results[0].orig_img.shape[:2] return results, input_size, original_size def set_conf_threshold(self, conf): self.conf_threshold = conf