[
  {
    "path": "README.md",
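    "content": "# 3D Lung Nodule Detection System\n\n## 0. Project Structure\n\n```\n├── data/               # Data handling\n│   ├── dataclass/      # Data class definitions (CTData, NoduleCube)\n│   └── preprocessing/  # Data preprocessing code, including LUNA16 processing\n├── models/             # Model definitions\n│   └── pytorch_c3d_tiny.py  # Tiny-C3D model definition\n├── training/           # Model training\n│   ├── pytorch_logs/   # Training logs\n│   ├── pytorch_checkpoints/ # Model checkpoints\n│   └── train_c3d_pytorch.py # Training script\n├── inference/          # Model inference\n│   ├── classifier.py   # Classifier implementation\n│   ├── detector.py     # Detector implementation\n│   └── pytorch_nodule_detector.py # PyTorch nodule detector\n├── deploy/             # Deployment\n│   ├── backend/        # Backend code\n│   ├── frontend/       # Frontend code\n│   └── run.py          # Launch script\n└── util/               # Utility functions\n```\n\n## 1. Tiny-C3D Model Architecture\n\nTiny-C3D is a lightweight 3D convolutional neural network designed for lung nodule classification. It is a simplified version of the C3D architecture: it keeps the core design while greatly reducing the number of parameters, so it can run in resource-constrained environments.\nThe original C3D model was proposed for video classification in [Learning Spatiotemporal Features with 3D Convolutional Networks](https://arxiv.org/abs/1412.0767).\n\n### Model Structure\n\nThe model contains four 3D convolution blocks built from the following components:\n- 3D convolution layers (one in blocks 1 and 2, two in blocks 3 and 4)\n- Batch normalization after each convolution\n- ReLU activation\n- Max pooling at the end of each block\n- Dropout layers (p=0.2) after blocks 2, 3, and 4 to reduce overfitting\n\nThe input is a 32×32×32 voxel cube. The first pooling layer only downsamples the two in-plane axes (kernel (1, 2, 2)), while the other three halve every axis, so the final feature map is 512×4×2×2 = 8192 values, matching the `in_features` of `fc1`. The full structure is:\n\n```commandline\nC3dTiny(\n  (conv_block1): Sequential(\n    (0): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (2): ReLU()\n    (3): MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0, dilation=1, ceil_mode=False)\n  )\n  (conv_block2): Sequential(\n    (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (2): ReLU()\n    (3): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (drop_out1): Dropout(p=0.2, inplace=False)\n  (conv_block3): Sequential(\n    (0): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (2): ReLU()\n    (3): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (4): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (5): ReLU()\n    (6): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (drop_out2): Dropout(p=0.2, inplace=False)\n  (conv_block4): Sequential(\n    (0): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (2): ReLU()\n    (3): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (4): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (5): ReLU()\n    (6): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (drop_out3): Dropout(p=0.2, inplace=False)\n  (flatten): Flatten(start_dim=1, end_dim=-1)\n  (fc1): Sequential(\n    (0): Linear(in_features=8192, out_features=512, bias=True)\n    (1): ReLU()\n  )\n  (fc2): Linear(in_features=512, out_features=2, bias=True)\n)\n```\n\n\n## 2. Data Preparation and Preprocessing\n\n### 2.1 Data Source\n\nThis project uses the [LUNA16 (Lung Nodule Analysis 2016)](https://luna16.grand-challenge.org/) dataset, which contains lung CT scans and nodule location annotations.\n\n### 2.2 Annotation Processing\n\n- Malignant (positive) nodule coordinates come from `annotations.csv`. The largest nodule radius is 16, so the extracted cubes have an edge length of 32.\n- Benign (negative) nodule coordinates come from `candidates_V2.csv`. Note that these annotations may be randomly generated and must be filtered. In the current code, `luna16_prepare_cube_data.py` calls `get_real_candidate` to filter them and saves the filtered result.\n\n\n### 2.3 Cube Preprocessing Pipeline\n\n1. **Nodule cube extraction**:\n   - Extract a 32×32×32 cube centered on each nodule from the original CT image\n   - Label each cube as benign or malignant according to its annotation\n\n2. **Normalization**:\n   - `normal_cube_to_tensor` normalizes the data to the [0, 1] range\n   - Invalid values such as NaN and Inf are repaired\n\n3. **Data augmentation (benign nodules only)**:\n   - Random rotation: a random angle within ±20 degrees in each of the three axis planes\n   - Random flip: flip along a chosen axis\n   - Gaussian noise: low-intensity Gaussian noise to make the model more robust\n\n4. **Class balancing**:\n   - Positive and negative samples are rebalanced to avoid class imbalance\n   - Negatives are sampled at twice the number of positives to make the model more sensitive to negatives. This does not fully solve the problem: the final model still produces many false positives, so results need further filtering by coordinates\n\n### 2.4 Data Class Design\n\n#### 2.4.1 CTData\n\nThis class loads raw CT data in DICOM or MHD format, converts it to HU values, segments the lung region, and normalizes it to standardized pixel data in the range [0, 1].\n\n#### 2.4.2 NoduleCube\n\nThis class holds a cube of a given edge length cut from the standardized [0, 1] pixel array (the extraction itself is done by `CTData.extract_cube`). It saves and loads cubes as `.npy` arrays and as PNG images.\n\n## 3. Model Training and Hyperparameters\n\n### 3.1 Training Parameters\n\n- **Batch size**: 64. Set this according to your GPU memory; this value was used on an NVIDIA 4070 GPU\n- **Learning rate**: 5e-4\n- **Weight decay**: 1e-5\n- **Optimizer**: Adam\n- **Loss function**: Cross-entropy loss\n- **Epochs**: 15\n- **Learning rate schedule**: ReduceLROnPlateau\n- **Gradient clipping**: 1.0\n\n### 3.2 Training Procedure\n\nThe training script is `training/train_c3d_pytorch.py`; set your own data directories in it.\n\n1. Load the preprocessed positive and negative samples\n2. Split them 8:2 into training and validation sets\n3. Apply data augmentation to improve generalization\n4. Evaluate on the validation set after each epoch\n5. Save the model with the best validation accuracy\n\n### 3.3 Training Results\n\nThe training log is shown below:\n\n```commandline\n2025-03-31 22:33:51,563 - c3d_training - INFO - Epoch [1/15], Train Loss: 0.3146, Train Acc: 89.05%, Val Loss: 0.1623, Val Acc: 93.28%, Time: 80.81s\n2025-03-31 22:35:12,255 - c3d_training - INFO - Epoch [2/15], Train Loss: 0.0834, Train Acc: 97.17%, Val Loss: 0.0728, Val Acc: 97.33%, Time: 80.39s\n2025-03-31 22:36:33,755 - c3d_training - INFO - Epoch [3/15], Train Loss: 0.0504, Train Acc: 98.30%, Val Loss: 0.0383, Val Acc: 98.61%, Time: 81.20s\n2025-03-31 22:37:55,955 - c3d_training - INFO - Epoch [4/15], Train Loss: 0.0368, Train Acc: 98.80%, Val Loss: 0.0522, Val Acc: 98.22%, Time: 81.89s\n2025-03-31 22:39:16,050 - c3d_training - INFO - Epoch [5/15], Train Loss: 0.0282, Train Acc: 99.04%, Val Loss: 0.0322, Val Acc: 98.98%, Time: 79.84s\n2025-03-31 22:40:37,157 - c3d_training - INFO - Epoch [6/15], Train Loss: 0.0244, Train Acc: 99.13%, Val Loss: 0.0393, Val Acc: 98.83%, Time: 80.80s\n2025-03-31 22:41:59,029 - c3d_training - INFO - Epoch [7/15], Train Loss: 0.0204, Train Acc: 99.40%, Val Loss: 0.0383, Val Acc: 98.95%, Time: 81.64s\n2025-03-31 22:43:20,016 - c3d_training - INFO - Epoch [8/15], Train Loss: 0.0219, Train Acc: 99.31%, Val Loss: 0.0578, Val Acc: 98.29%, Time: 80.75s\n2025-03-31 22:44:41,369 - c3d_training - INFO - Epoch [9/15], Train Loss: 0.0171, Train Acc: 99.52%, Val Loss: 0.0342, Val Acc: 99.17%, Time: 81.13s\n2025-03-31 22:46:02,745 - c3d_training - INFO - Epoch [10/15], Train Loss: 0.0173, Train Acc: 99.49%, Val Loss: 0.0279, Val Acc: 99.15%, Time: 81.04s\n2025-03-31 22:47:24,128 - c3d_training - INFO - Epoch [11/15], Train Loss: 0.0176, Train Acc: 99.47%, Val Loss: 0.0349, Val Acc: 99.17%, Time: 81.15s\n2025-03-31 22:48:44,717 - c3d_training - INFO - Epoch [12/15], Train Loss: 0.0140, Train Acc: 99.53%, Val Loss: 0.0362, Val Acc: 99.27%, Time: 80.36s\n2025-03-31 22:50:07,041 - c3d_training - INFO - Epoch [13/15], Train Loss: 0.0126, Train Acc: 99.62%, Val Loss: 0.0497, Val Acc: 98.87%, Time: 82.02s\n2025-03-31 22:51:27,896 - c3d_training - INFO - Epoch [14/15], Train Loss: 0.0152, Train Acc: 99.49%, Val Loss: 0.0271, Val Acc: 99.32%, Time: 80.61s\n2025-03-31 22:52:49,409 - c3d_training - INFO - Epoch [15/15], Train Loss: 0.0127, Train Acc: 99.58%, Val Loss: 0.0277, Val Acc: 99.15%, Time: 81.20s\n```\n\nLoss and accuracy curves over the epochs:\n\n![metric](./training/pytorch_checkpoints/training_metrics.png)\n\n## 4. Model Inference\n\nInference runs as a two-stage pipeline:\n\n### 4.1 Nodule Detection\n\n- Scan the full CT volume with a sliding window\n- Detect candidate nodule locations\n- Merge overlapping detections with non-maximum suppression\n\n### 4.2 Nodule Classification\n\n- Extract features from the detected candidate regions\n- Apply the trained Tiny-C3D model\n- Output the nodule probability and malignancy score\n\n### 4.3 Inference Optimizations\n\n- Batched inference for higher throughput\n- Threshold-based filtering of low-confidence detections\n- Multithreading for parallel processing\n\n## 5. Model Deployment and Results\n\n### 5.1 Deployment Architecture\n\nAll deployment code is in the `deploy` directory and can be extracted and deployed on its own. Copy the trained model into `deploy/backend/data`; the current code loads `c3d_nodule_detect.pth`, which is the best model saved during training.\n\nThe system uses a decoupled frontend/backend architecture:\n\n- **Backend**: Flask RESTful API service\n- **Frontend**: Interactive 3D visualization built on HTML5 and WebGL\n\n### 5.2 User Interface\n\n- Upload of CT data in multiple formats. Only MHD has been tested (compress the `.mhd` and `.raw` files of the same case into a single archive before uploading); DICOM has not been tested\n- Nodule location markers with probability display\n- Interactive browsing of different nodule views\n\n### 5.3 Deployment Results\n\nRunning locally, the system processes a single CT scan in about 3 minutes. Classifying one cube takes only about 0.2 seconds; most of the time is spent scanning the volume cube by cube. Detection and classification accuracy reaches a level that can serve as a clinical reference, providing diagnostic support for radiologists.\n\n![Deployment result](./deploy.png)"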
  },
  {
    "path": "data/__init__.py",
    "content": ""
  },
  {
    "path": "data/dataclass/CTData.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\nimport numpy as np\nimport SimpleITK as sitk\nfrom scipy import ndimage\nfrom enum import Enum\nimport matplotlib.pyplot as plt\nfrom util.dicom_util import load_dicom_slices, get_pixels_hu, get_dicom_thickness\nfrom util.seg_util import get_segmented_lungs, normalize_hu_values\n\n\nclass CTFormat(Enum):\n    DICOM = 1\n    MHD = 2\n    UNKNOWN = 3\n\n\nclass CTData:\n    \"\"\"\n    Unified CT data class for handling CT images stored in different formats.\n    Supports loading, processing, and analyzing DICOM and MHD data.\n    \"\"\"\n    def __init__(self):\n        # Basic attributes\n        self.pixel_data = None  # Pixel data, 3D voxel array (z, y, x)\n        self.lung_seg_img = None    # CT image data restricted to the lung region\n        self.lung_seg_mask = None   # Lung mask\n        self.origin = None  # Coordinate origin (x, y, z) in mm\n        self.spacing = None  # Voxel spacing (x, y, z) in mm\n        self.orientation = None  # Direction matrix\n        self.z_axis_flip = False    # Whether the z axis is flipped\n        self.size = None  # Image size (z, y, x)\n        self.data_format = None  # Data format (DICOM/MHD)\n        self.metadata = {}  # Other metadata\n        self.hu_converted = False  # Whether the values have been converted to HU\n        self.preprocessed = False   # Whether the data has already been preprocessed\n\n    @classmethod\n    def from_dicom(cls, dicom_path):\n        \"\"\"\n        Load CT data from a DICOM folder.\n\n        Args:\n            dicom_path: Path to the DICOM folder\n\n        Returns:\n            CTData object\n        \"\"\"\n        ct_data = cls()\n        ct_data.data_format = CTFormat.DICOM\n        slices = load_dicom_slices(dicom_path)\n        ct_data.pixel_data = get_pixels_hu(slices)\n        ct_data.z_axis_flip = slices[1].ImagePositionPatient[2] > slices[0].ImagePositionPatient[2]\n        ct_data.hu_converted = True\n        slice_thickness = get_dicom_thickness(slices)\n        # Set the voxel spacing\n        try:\n            ct_data.spacing = [\n                float(slices[0].PixelSpacing[0]),\n                float(slices[0].PixelSpacing[1]),\n                float(slice_thickness)\n            ]\n        except Exception:\n            print(\"Warning: could not read the pixel spacing, using the default [1.0, 1.0, 1.0]\")\n            ct_data.spacing = [1.0, 1.0, 1.0]\n        # Set the origin\n        try:\n            ct_data.origin = [\n                float(slices[0].ImagePositionPatient[0]),\n                float(slices[0].ImagePositionPatient[1]),\n                float(slices[0].ImagePositionPatient[2])\n            ]\n        except Exception:\n            print(\"Warning: could not read the coordinate origin, using the default [0.0, 0.0, 0.0]\")\n            ct_data.origin = [0.0, 0.0, 0.0]\n        # Set the size\n        ct_data.size = ct_data.pixel_data.shape\n        return ct_data\n\n    @classmethod\n    def from_mhd(cls, mhd_path):\n        \"\"\"\n        Load CT data from an MHD/RAW file pair.\n\n        Args:\n            mhd_path: Path to the .mhd file\n\n        Returns:\n            CTData object\n        \"\"\"\n        ct_data = cls()\n        ct_data.data_format = CTFormat.MHD\n        try:\n            # Load the MHD file with SimpleITK\n            itk_img = sitk.ReadImage(mhd_path)\n            # Get the pixel data (note: SimpleITK returns the array in z, y, x order)\n            ct_data.pixel_data = sitk.GetArrayFromImage(itk_img)\n            # LUNA16 MHD data is already stored in HU\n            ct_data.hu_converted = True\n            # Origin and voxel spacing\n            ct_data.origin = list(itk_img.GetOrigin())  # (x, y, z)\n            ct_data.spacing = list(itk_img.GetSpacing())  # (x, y, z)\n            # Size\n            ct_data.size = ct_data.pixel_data.shape\n            # Direction information\n            ct_data.orientation = itk_img.GetDirection()\n            ct_data.z_axis_flip = False\n        except Exception as e:\n            raise ValueError(f\"Error while loading MHD file: {e}\") from e\n\n        return ct_data\n\n    def convert_to_hu(self):\n        \"\"\"\n        Convert pixel values to HU (if not already converted).\n        \"\"\"\n        if self.hu_converted:\n            print(\"Data is already in HU\")\n            return\n\n        if self.data_format == CTFormat.DICOM:\n            # Already handled in from_dicom\n            self.hu_converted = True\n        elif self.data_format == CTFormat.MHD:\n            # LUNA16 MHD data is already stored in HU\n            self.hu_converted = True\n        else:\n            raise ValueError(\"Unknown data format, cannot convert to HU\")\n\n    def resample_pixel(self, new_spacing=(1, 1, 1)):\n        \"\"\"\n        Resample the CT volume to the given voxel spacing.\n\n        Args:\n            new_spacing: Target voxel spacing [x, y, z]\n\n        Returns:\n            Resampled CTData object\n        \"\"\"\n        # Make sure the data is in HU\n        if not self.hu_converted:\n            self.convert_to_hu()\n        # scipy.ndimage works on the (z, y, x) array, so reorder the spacings to match pixel_data\n        spacing_zyx = [self.spacing[2], self.spacing[1], self.spacing[0]]\n        new_spacing_zyx = [new_spacing[2], new_spacing[1], new_spacing[0]]\n        # Compute the new shape\n        resize_factor = np.array(spacing_zyx) / np.array(new_spacing_zyx)\n        new_shape = np.round(np.array(self.pixel_data.shape) * resize_factor)\n        # Actual zoom factors after rounding the shape to whole voxels\n        real_resize = new_shape / np.array(self.pixel_data.shape)\n        # Resample with linear interpolation (order=1, i.e. trilinear for 3D data)\n        resampled_data = ndimage.zoom(self.pixel_data, real_resize, order=1)\n        # Create the new CTData object\n        resampled_ct = CTData()\n        resampled_ct.pixel_data = resampled_data\n        # Because of rounding, the exact spacing can differ slightly from new_spacing;\n        # store the spacing actually implied by the new shape, in (x, y, z) order\n        resampled_ct.spacing = (np.array(self.spacing) / real_resize[::-1]).tolist()\n        resampled_ct.origin = self.origin\n        resampled_ct.orientation = self.orientation\n        resampled_ct.z_axis_flip = self.z_axis_flip\n        resampled_ct.size = resampled_data.shape\n        resampled_ct.data_format = self.data_format\n        resampled_ct.hu_converted = self.hu_converted\n        resampled_ct.preprocessed = self.preprocessed\n        return resampled_ct\n\n    def filter_lung_img_mask(self):\n        \"\"\"\n        Keep only the lung-region pixels (segmented slice by slice) and normalize them to [0, 1].\n        Stores the results in lung_seg_img and lung_seg_mask.\n        \"\"\"\n        pixel_data = self.pixel_data.copy()\n        seg_img = []\n        seg_mask = []\n        for index in range(pixel_data.shape[0]):\n            one_seg_img, one_seg_mask = get_segmented_lungs(pixel_data[index])\n            one_seg_img = normalize_hu_values(one_seg_img)\n            seg_img.append(one_seg_img)\n            seg_mask.append(one_seg_mask)\n        self.lung_seg_img = np.array(seg_img)\n        self.lung_seg_mask = np.array(seg_mask)\n\n    def world_to_voxel(self, world_coord):\n        \"\"\"\n        Convert world coordinates (mm) to voxel coordinates.\n        Assumes an identity direction matrix (no rotation between the axes).\n\n        Args:\n            world_coord: World coordinates [x, y, z] (mm)\n\n        Returns:\n            Voxel coordinates [x, y, z]\n        \"\"\"\n        voxel_coord = np.zeros(3, dtype=int)\n        for i in range(3):\n            voxel_coord[i] = int(round((world_coord[i] - self.origin[i]) / self.spacing[i]))\n\n        return voxel_coord\n\n    def voxel_to_world(self, voxel_coord):\n        \"\"\"\n        Convert voxel coordinates to world coordinates (mm).\n        Assumes an identity direction matrix (no rotation between the axes).\n\n        Args:\n            voxel_coord: Voxel coordinates [x, y, z]\n\n        Returns:\n            World coordinates [x, y, z] (mm)\n        \"\"\"\n        world_coord = np.zeros(3, dtype=float)\n        for i in range(3):\n            world_coord[i] = voxel_coord[i] * self.spacing[i] + self.origin[i]\n\n        return world_coord\n\n    def extract_cube(self, center_world_mm, size_mm, if_fixed_radius=False):\n        \"\"\"\n        Extract a cube with the given center and size from the lung-segmented volume.\n\n        Args:\n            center_world_mm:    World coordinates of the cube center [x, y, z] (mm)\n            size_mm:            Cube edge length in world coordinates (mm), as a single number\n            if_fixed_radius:    Whether to use a fixed cube size. Defaults to False: each nodule then has its\n                                own size and the cube is extracted according to the annotated size.\n                                When True, size_mm is interpreted as the edge length in voxels.\n\n        Returns:\n            Cube pixel data (smaller than requested if the cube touches the volume border)\n        \"\"\"\n        # Make sure the data is loaded\n        if self.pixel_data is None:\n            raise ValueError(\"No data loaded\")\n        if self.lung_seg_img is None:\n            print(\"Lung region has not been segmented yet, segmenting now...\")\n            self.filter_lung_img_mask()\n        # Convert world coordinates to voxel coordinates (note: SimpleITK arrays are in z, y, x order)\n        center_voxel = self.world_to_voxel(center_world_mm)\n        # Reorder to z, y, x to match pixel_data\n        center_voxel_zyx = [center_voxel[2], center_voxel[1], center_voxel[0]]\n        # With a fixed size only the center is needed: size_mm is the edge length in voxels,\n        # and the cube is cut directly from lung_seg_img\n        if if_fixed_radius:\n            half_size = [int(size_mm / 2), int(size_mm / 2), int(size_mm / 2)]\n        else:\n            # Edge length in voxels. LUNA16 annotates each nodule with a different diameter, so cubes\n            # extracted this way vary in size; a fixed size is preferable\n            size_voxel = [int(size_mm / self.spacing[2]),\n                          int(size_mm / self.spacing[1]),\n                          int(size_mm / self.spacing[0])]\n            # Half edge length per axis\n            half_size = [s // 2 for s in size_voxel]\n        # Cube bounds, clipped to the volume\n        z_min = max(0, center_voxel_zyx[0] - half_size[0])\n        y_min = max(0, center_voxel_zyx[1] - half_size[1])\n        x_min = max(0, center_voxel_zyx[2] - half_size[2])\n\n        z_max = min(self.lung_seg_img.shape[0], center_voxel_zyx[0] + half_size[0])\n        y_max = min(self.lung_seg_img.shape[1], center_voxel_zyx[1] + half_size[1])\n        x_max = min(self.lung_seg_img.shape[2], center_voxel_zyx[2] + half_size[2])\n        # Extract the sub-volume\n        cube = self.lung_seg_img[z_min:z_max, y_min:y_max, x_min:x_max]\n        return cube\n\n    def visualize_slice(self, slice_idx=None, axis=0, show_lung_only=False):\n        \"\"\"\n        Visualize a single slice.\n\n        Args:\n            slice_idx:          Slice index; if None, the center slice is used\n            axis:               Axis to slice along (0=z, 1=y, 2=x)\n            show_lung_only:     Whether to show only the lung region, with everything else as black background\n        \"\"\"\n        # Make sure the data is loaded\n        if self.pixel_data is None:\n            raise ValueError(\"No data loaded\")\n        # Determine the slice index\n        if slice_idx is None:\n            slice_idx = self.pixel_data.shape[axis] // 2\n        # Select the volume to slice (segment the lungs first if needed)\n        if show_lung_only:\n            if self.lung_seg_img is None:\n                self.filter_lung_img_mask()\n            volume = self.lung_seg_img\n        else:\n            volume = self.pixel_data\n        # Extract the slice\n        if axis == 0:  # z axis\n            slice_data = volume[slice_idx, :, :]\n        elif axis == 1:  # y axis\n            slice_data = volume[:, slice_idx, :]\n        else:  # x axis\n            slice_data = volume[:, :, slice_idx]\n        # Create the figure and show the image\n        plt.figure(figsize=(10, 8))\n        plt.imshow(slice_data, cmap='gray')\n        # Set the title\n        axis_name = ['z', 'y', 'x'][axis]\n        plt.title(f\"Slice {slice_idx} (along the {axis_name} axis)\")\n        plt.colorbar(label='Pixel value')\n        plt.axis('off')\n        plt.tight_layout()\n        plt.show()\n\n    def visualize_nodule(self, coord_x, coord_y, coord_z, diameter):\n        \"\"\"\n        Visualize a nodule as three orthogonal slices through its center.\n\n        Args:\n            coord_x: Nodule center x in world coordinates (mm)\n            coord_y: Nodule center y in world coordinates (mm)\n            coord_z: Nodule center z in world coordinates (mm)\n            diameter: Nodule diameter (mm)\n        \"\"\"\n        # Extract the nodule cube\n        cube_size = max(32, int(diameter * 1.5))  # Make sure the cube is large enough\n        cube = self.extract_cube([coord_x, coord_y, coord_z], cube_size)\n        # Show three orthogonal planes\n        fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n        # Center slice indices\n        center_z = cube.shape[0] // 2\n        center_y = cube.shape[1] // 2\n        center_x = cube.shape[2] // 2\n        # Draw the three orthogonal planes\n        axes[0].imshow(cube[center_z, :, :], cmap='gray')\n        axes[0].set_title(f'Axial view (z={center_z})')\n        axes[0].axis('off')\n        axes[1].imshow(cube[:, center_y, :], cmap='gray')\n        axes[1].set_title(f'Coronal view (y={center_y})')\n        axes[1].axis('off')\n        axes[2].imshow(cube[:, :, center_x], cmap='gray')\n        axes[2].set_title(f'Sagittal view (x={center_x})')\n        axes[2].axis('off')\n        fig.suptitle(f\"Nodule - position: ({coord_x:.1f}, {coord_y:.1f}, {coord_z:.1f}) mm, \"\n                     f\"diameter: {diameter:.1f} mm\", fontsize=14)\n        plt.tight_layout()\n        plt.show()\n\n    def save_as_nifti(self, output_path):\n        \"\"\"\n        Save the CT data in NIfTI format.\n\n        Args:\n            output_path: Output file path (e.g. ending in .nii or .nii.gz)\n        \"\"\"\n        # Make sure the data is loaded\n        if self.pixel_data is None:\n            raise ValueError(\"No data loaded\")\n\n        # Create a SimpleITK image\n        # Note: SimpleITK arrays are in z, y, x order\n        img = sitk.GetImageFromArray(self.pixel_data)\n        img.SetOrigin(self.origin)\n        img.SetSpacing(self.spacing)\n\n        if self.orientation is not None:\n            img.SetDirection(self.orientation)\n        # Write the file (SimpleITK picks the NIfTI writer from the file extension)\n        sitk.WriteImage(img, output_path)\n        print(f\"Saved as NIfTI: {output_path}\")\n"
  },
  {
    "path": "data/dataclass/NoduleCube.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\nimport os\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport cv2\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nfrom scipy import ndimage\n\n\ndef normal_cube_to_tensor(cube_data):\n    \"\"\"\n    Normalize cube data and convert it to a PyTorch tensor. Used in both training and inference.\n    :param cube_data: ndarray of shape [32, 32, 32]\n    :return: float32 tensor of shape (1, 32, 32, 32)\n    \"\"\"\n    cube_data = cube_data.astype(np.float32)\n    # Min-max normalize to [0, 1]\n    min_val = np.min(cube_data)\n    max_val = np.max(cube_data)\n    data_range = max_val - min_val\n    # Avoid division by zero\n    if data_range < 1e-10:\n        normalized_cube = np.zeros_like(cube_data)\n    else:\n        normalized_cube = (cube_data - min_val) / data_range\n    # Check for invalid values and repair them\n    if np.isnan(normalized_cube).any() or np.isinf(normalized_cube).any():\n        normalized_cube = np.nan_to_num(normalized_cube, nan=0.0, posinf=1.0, neginf=0.0)\n    # Convert to a PyTorch tensor and add the channel dimension -> (1, 32, 32, 32);\n    # the batch dimension is added separately (e.g. by the DataLoader)\n    cube_tensor = torch.from_numpy(normalized_cube).float().unsqueeze(0)\n    return cube_tensor\n\n\n@dataclass\nclass NoduleCube:\n    \"\"\"\n    Lung nodule cube: 3D cube data around a lung nodule.\n    Independent of the CT data; only handles cubes that have already been extracted.\n    \"\"\"\n    # Basic attributes\n    cube_size: int = 64  # Cube edge length (default 64x64x64)\n    pixel_data: Optional[np.ndarray] = None  # Pixel data, shape: [cube_size, cube_size, cube_size]\n\n    # Nodule features\n    center_x: int = 0  # Nodule center x coordinate\n    center_y: int = 0  # Nodule center y coordinate\n    center_z: int = 0  # Nodule center z coordinate\n    radius: float = 0.0  # Nodule radius\n    malignancy: int = 0  # Malignancy (0 = benign, 1 = malignant)\n\n    # File paths\n    npy_path: str = \"\"  # .npy file path\n    png_path: str = \"\"  # .png file path\n\n    def __post_init__(self):\n        \"\"\"Called after initialization.\"\"\"\n        # If npy_path is given but pixel_data is not, try to load it\n        if self.npy_path and self.pixel_data is None:\n            self.load_from_npy()\n        # If png_path is given but pixel_data is not, try to load it\n        elif self.png_path and self.pixel_data is None:\n            self.load_from_png()\n\n    def load_from_npy(self) -> None:\n        \"\"\"Load cube data from an .npy file.\"\"\"\n        if not os.path.exists(self.npy_path):\n            raise FileNotFoundError(f\"File not found: {self.npy_path}\")\n\n        try:\n            self.pixel_data = np.load(self.npy_path)\n            # Validate the dimensionality\n            if len(self.pixel_data.shape) != 3:\n                raise ValueError(f\"Pixel data must be a 3D array, got shape: {self.pixel_data.shape}\")\n\n            # Resize if the shape does not match cube_size\n            if self.pixel_data.shape != (self.cube_size,) * 3:\n                self.resize(self.cube_size)\n\n        except Exception as e:\n            raise ValueError(f\"Error while loading NPY file: {e}\") from e\n\n    def save_to_npy(self, output_path: str) -> str:\n        \"\"\"\n        Save the cube data to an .npy file.\n\n        Args:\n            output_path: Output path\n\n        Returns:\n            Path of the saved file\n        \"\"\"\n        if self.pixel_data is None:\n            raise ValueError(\"No pixel data to save\")\n\n        # Make sure the output directory exists (skipped when saving to the current directory)\n        output_dir = os.path.dirname(output_path)\n        if output_dir:\n            os.makedirs(output_dir, exist_ok=True)\n        np.save(output_path, self.pixel_data)\n        self.npy_path = output_path\n        return output_path\n\n    def save_to_png(self, output_path: str) -> str:\n        \"\"\"\n        Save the cube data as a PNG image with the slices tiled in a grid (8x8 for a 64-slice cube).\n\n        Args:\n            output_path: Output PNG file path\n\n        Returns:\n            Path of the saved file\n        \"\"\"\n        if self.pixel_data is None:\n            raise ValueError(\"No pixel data to save\")\n\n        # Make sure the output directory exists (skipped when saving to the current directory)\n        output_dir = os.path.dirname(output_path)\n        if output_dir:\n            os.makedirs(output_dir, exist_ok=True)\n\n        # Grid position of each slice in the final image (8 rows x 8 columns for 64 slices)\n        rows, cols = 8, 8\n        if self.cube_size != 64:\n            # For other sizes, pick a rows x cols factorization that is as close to square as possible\n            total_slices = self.cube_size\n            rows = int(np.sqrt(total_slices))\n            while total_slices % rows != 0:\n                rows -= 1\n            cols = total_slices // rows\n\n        # Create the tiled image\n        img_height = self.cube_size\n        img_width = self.cube_size\n        combined_img = np.zeros((rows * img_height, cols * img_width), dtype=np.uint8)\n\n        # Fill the tiled image\n        for i in range(self.cube_size):\n            row = i // cols\n            col = i % cols\n\n            slice_data = self.pixel_data[i]\n\n            # Map values to 0-255 (data in [0, 1] is scaled, anything else is cast directly)\n            if slice_data.max() <= 1.0:\n                slice_data = (slice_data * 255).astype(np.uint8)\n            else:\n                slice_data = slice_data.astype(np.uint8)\n\n            # Place the slice in the tiled image\n            y_start = row * img_height\n            x_start = col * img_width\n            combined_img[y_start:y_start + img_height, x_start:x_start + img_width] = slice_data\n\n        # Save the tiled image\n        cv2.imwrite(output_path, combined_img)\n        self.png_path = output_path\n        return output_path\n\n    def load_from_png(self) -> None:\n        \"\"\"Load cube data from a PNG image (slices tiled in a grid, 8x8 for a 64-slice cube).\"\"\"\n        if not os.path.exists(self.png_path):\n            raise FileNotFoundError(f\"File not found: {self.png_path}\")\n\n        try:\n            # Read the PNG image as grayscale\n            img = cv2.imread(self.png_path, cv2.IMREAD_GRAYSCALE)\n            if img is None:\n                raise ValueError(f\"Could not decode image: {self.png_path}\")\n\n            # Determine the grid layout (must match save_to_png)\n            rows, cols = 8, 8\n            if self.cube_size != 64:\n                # For other sizes, use the same near-square factorization as save_to_png\n                total_slices = self.cube_size\n                rows = int(np.sqrt(total_slices))\n                while total_slices % rows != 0:\n                    rows -= 1\n                cols = total_slices // rows\n\n            # Check that the image size is correct\n            expected_height = rows * self.cube_size\n            expected_width = cols * self.cube_size\n            if img.shape[0] != expected_height or img.shape[1] != expected_width:\n                raise ValueError(f\"Image size mismatch: expected {expected_height}x{expected_width}, got {img.shape[0]}x{img.shape[1]}\")\n\n            # Allocate the 3D array\n            cube_data = np.zeros((self.cube_size, self.cube_size, self.cube_size), dtype=np.float32)\n\n            # Extract each slice from the PNG image\n            for i in range(self.cube_size):\n                row = i // cols\n                col = i % cols\n\n                y_start = row * self.cube_size\n                x_start = col * self.cube_size\n\n                slice_data = img[y_start:y_start + self.cube_size, x_start:x_start + self.cube_size]\n                cube_data[i] = slice_data.astype(np.float32) / 255.0  # Normalize to [0, 1]\n\n            self.pixel_data = cube_data\n\n        except Exception as e:\n            raise ValueError(f\"Error while loading PNG file: {e}\") from e\n\n    def set_cube_data(self, pixel_data: np.ndarray) -> None:\n        \"\"\"\n        Set the cube pixel data.\n\n        Args:\n            pixel_data: 3D pixel data\n        \"\"\"\n        if len(pixel_data.shape) != 3:\n            raise ValueError(f\"Pixel data must be a 3D array, got shape: {pixel_data.shape}\")\n\n        self.pixel_data = pixel_data\n\n        # Resize if the shape does not match cube_size\n        if self.pixel_data.shape != (self.cube_size,) * 3:\n            self.resize(self.cube_size)\n\n    def resize(self, new_size: int) -> None:\n        \"\"\"\n        Resize the cube.\n\n        Args:\n            new_size: New cube edge length\n        \"\"\"\n        if self.pixel_data is None:\n            raise ValueError(\"No pixel data to resize\")\n\n        # Zoom factor per axis\n        zoom_factors = [new_size / self.pixel_data.shape[0],\n                        new_size / self.pixel_data.shape[1],\n                        new_size / self.pixel_data.shape[2]]\n\n        # Resample with scipy.ndimage (default cubic spline interpolation; mode='nearest' only\n        # controls how samples outside the array are filled, by repeating the edge values)\n        self.pixel_data = ndimage.zoom(self.pixel_data, zoom_factors, mode='nearest')\n        self.cube_size = new_size\n\n    def augment(self, rotation: bool = True, flip_axis: int = -1, noise: bool = True) -> 'NoduleCube':\n        \"\"\"\n        Data augmentation.\n\n        Args:\n            rotation: Whether to apply random rotation\n            flip_axis: Axis to flip along; -1 (default) means no flip\n            noise: Whether to add Gaussian noise\n\n        Returns:\n            A new, augmented cube instance\n        \"\"\"\n        if self.pixel_data is None:\n            raise ValueError(\"No pixel data to augment\")\n\n        # Work on a copy\n        augmented_cube = self.pixel_data.copy()\n\n        # Rotation\n        if rotation:\n            # Random angle in [-20, 20) degrees for each of the three rotation planes\n            angles = np.random.uniform(-20, 20, 3)\n            augmented_cube = ndimage.rotate(augmented_cube, angles[0], axes=(1, 2), reshape=False, mode='nearest')\n            augmented_cube = ndimage.rotate(augmented_cube, angles[1], axes=(0, 2), reshape=False, mode='nearest')\n            augmented_cube = ndimage.rotate(augmented_cube, angles[2], axes=(0, 1), reshape=False, mode='nearest')\n\n        # Flip\n        if flip_axis >= 0:\n            augmented_cube = np.flip(augmented_cube, axis=flip_axis)\n\n        # Noise\n        if noise:\n            # Add Gaussian noise with a random standard deviation in [0, 0.05)\n            noise_level = np.random.uniform(0.0, 0.05)\n            noise_array = np.random.normal(0, noise_level, augmented_cube.shape)\n            augmented_cube = augmented_cube + noise_array\n            # Keep the values in [0, 1]\n            augmented_cube = np.clip(augmented_cube, 0, 1)\n\n        # Create the new instance\n        new_cube = NoduleCube(\n            cube_size=self.cube_size,\n            center_x=self.center_x,\n            center_y=self.center_y,\n            center_z=self.center_z,\n            radius=self.radius,\n            malignancy=self.malignancy\n        )\n\n        new_cube.set_cube_data(augmented_cube)\n        return new_cube\n\n    def visualize_3d(self, output_path: Optional[str] = None, show: bool = True) -> None:\n        \"\"\"\n        Visualize the cube: three orthogonal center slices and three maximum intensity projections.\n\n        Args:\n            output_path: Optional output path; if given, the figure is saved there\n            show: Whether to display the figure\n        \"\"\"\n        if self.pixel_data is None:\n            raise ValueError(\"No pixel data to visualize\")\n\n        # Create the figure\n        fig, axes = plt.subplots(2, 3, figsize=(12, 8))\n\n        # Center slice indices\n        center_z = self.pixel_data.shape[0] // 2\n        center_y = self.pixel_data.shape[1] // 2\n        center_x = self.pixel_data.shape[2] // 2\n\n        # Three orthogonal planes\n        slice_xy = self.pixel_data[center_z, :, :]\n        slice_xz = self.pixel_data[:, center_y, :]\n        slice_yz = self.pixel_data[:, :, center_x]\n\n        # Show the orthogonal views (fixing y gives the coronal plane, fixing x the sagittal plane)\n        axes[0, 0].imshow(slice_xy, cmap='gray')\n        axes[0, 0].set_title(f'Axial view (Z={center_z})')\n\n        axes[0, 1].imshow(slice_xz, cmap='gray')\n        axes[0, 1].set_title(f'Coronal view (Y={center_y})')\n\n        axes[0, 2].imshow(slice_yz, cmap='gray')\n        axes[0, 2].set_title(f'Sagittal view (X={center_x})')\n\n        # 3D overview via MIP (maximum intensity projection)\n        mip_xy = np.max(self.pixel_data, axis=0)\n        mip_xz = np.max(self.pixel_data, axis=1)\n        mip_yz = np.max(self.pixel_data, axis=2)\n\n        axes[1, 0].imshow(mip_xy, cmap='gray')\n        axes[1, 0].set_title('MIP (axial)')\n\n        axes[1, 1].imshow(mip_xz, cmap='gray')\n        axes[1, 1].set_title('MIP (coronal)')\n\n        axes[1, 2].imshow(mip_yz, cmap='gray')\n        axes[1, 2].set_title('MIP (sagittal)')\n\n        # Nodule information\n        nodule_info = f\"Nodule center: ({self.center_x}, {self.center_y}, {self.center_z})\\n\"\n        nodule_info += f\"Radius: {self.radius:.1f}\\n\"\n        nodule_info += f\"Malignancy: {'malignant' if self.malignancy == 1 else 'benign'}\"\n\n        fig.suptitle(nodule_info, fontsize=12)\n        plt.tight_layout()\n\n        if output_path:\n            plt.savefig(output_path, dpi=200, bbox_inches='tight')\n\n        if show:\n            plt.show()\n        else:\n            plt.close(fig)\n\n    @classmethod\n    def from_npy(cls, file_path: str, cube_size: int = 64) -> 'NoduleCube':\n        \"\"\"\n        Create a cube instance from an .npy file.\n\n        Args:\n            file_path: .npy file path\n            cube_size: Cube edge length\n\n        Returns:\n            NoduleCube instance\n        \"\"\"\n        # __post_init__ already loads the file because npy_path is set\n        return cls(cube_size=cube_size, npy_path=file_path)\n\n    @classmethod\n    def from_png(cls, file_path: str, cube_size: int = 64) -> 'NoduleCube':\n        \"\"\"\n        Create a cube instance from a PNG file.\n\n        Args:\n            file_path: PNG file path\n            cube_size: Cube edge length\n\n        Returns:\n            NoduleCube instance\n        \"\"\"\n        # __post_init__ already loads the file because png_path is set\n        return cls(cube_size=cube_size, png_path=file_path)\n\n    @classmethod\n    def from_array(cls,\n                   pixel_data: np.ndarray,\n                   center_x: int = 0,\n                   center_y: int = 0,\n                   center_z: int = 0,\n                   radius: float = 0.0,\n                   malignancy: int = 0) -> 'NoduleCube':\n        \"\"\"\n        Create a cube instance from a NumPy array.\n\n        Args:\n            pixel_data: 3D pixel data\n            center_x: Center X coordinate\n            center_y: Center Y coordinate\n            center_z: Center Z coordinate\n            radius: Nodule radius\n            malignancy: Malignancy (0 = benign, 1 = malignant)\n\n        Returns:\n            NoduleCube instance\n        \"\"\"\n        if len(pixel_data.shape) != 3:\n            raise ValueError(f\"Pixel data must be a 3D array, got shape: {pixel_data.shape}\")\n\n        cube_size = pixel_data.shape[0]\n        if pixel_data.shape[1] != cube_size or pixel_data.shape[2] != cube_size:\n            raise ValueError(f\"Pixel data must be cube-shaped, got shape: {pixel_data.shape}\")\n\n        cube = cls(\n            cube_size=cube_size,\n            center_x=center_x,\n            center_y=center_y,\n            center_z=center_z,\n            radius=radius,\n            malignancy=malignancy\n        )\n\n        cube.set_cube_data(pixel_data)\n        return cube\n"
  },
  {
    "path": "data/dataclass/__init__.py",
    "content": "\n"
  },
  {
    "path": "data/preprocessing/__init__.py",
    "content": ""
  },
  {
    "path": "data/preprocessing/lidc_process/README.md",
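    "content": "# LIDC-IDRI Data Preprocessing\n\nThis directory contains preprocessing scripts for the [LIDC-IDRI (Lung Image Database Consortium and Image Database Resource Initiative)](https://www.cancerimagingarchive.net/collection/lidc-idri/) dataset. LIDC-IDRI is a public lung CT dataset containing the chest CT scans of more than 1000 patients together with the corresponding radiologist annotations.\n\n## LIDC-IDRI Dataset Details\n\nThe LIDC-IDRI dataset was created by seven academic centers and eight medical imaging companies and contains 1018 cases. Each case includes the images of a clinical chest CT scan and the annotations made by four experienced thoracic radiologists. Annotation was carried out in two phases:\n\n1. **Blinded-read phase**: each radiologist independently reviewed every CT scan and marked lesions belonging to one of three categories:\n   - nodules ≥3 mm in diameter\n   - nodules <3 mm in diameter\n   - non-nodules ≥3 mm in diameter\n\n2. **Unblinded-read phase**: each radiologist independently reviewed their own marks together with the anonymized marks of the other three radiologists to form a final opinion.\n\n### Data Components\n\n- **CT scans**: chest CT DICOM files of the 1018 cases\n- **Annotation XML files**: XML annotations containing nodule location, size and characteristics\n- **Nodule diagnosis information**: malignancy ratings and other characteristics of each nodule\n\n### Annotation Characteristics\n\nEach marked nodule carries the following characteristic ratings:\n\n| Characteristic | Rating range | Meaning |\n| ------- | ------- | ---- |\n| malignancy | 1-5 | 1 = highly unlikely (benign), 5 = highly suspicious (malignant) |\n| sphericity | 1-5 | 1 = linear, 5 = round |\n| margin | 1-5 | 1 = poorly defined, 5 = sharp |\n| spiculation | 1-5 | 1 = no spiculation, 5 = marked spiculation |\n| texture | 1-5 | 1 = non-solid, 5 = solid |\n| calcification | 1-6 | categorical: type of calcification |\n| internal structure | 1-4 | categorical: type of internal structure |\n| lobulation | 1-5 | 1 = no lobulation, 5 = marked lobulation |\n| subtlety | 1-5 | 1 = extremely subtle, 5 = obvious |\n\n## Data Examples\n\n### XML Annotation Example\n\n```xml\n<LidcReadMessage>\n  <ResponseHeader>\n    <SeriesInstanceUid>1.3.6.1.4.1.14519.5.2.1.6279.6001.123456789</SeriesInstanceUid>\n  </ResponseHeader>\n  <readingSession>\n    <servicingRadiologistID>Reader1</servicingRadiologistID>\n    <unblindedReadNodule>\n      <noduleID>Nodule001</noduleID>\n      <characteristics>\n        <malignancy>4</malignancy>\n        <sphericity>5</sphericity>\n        <margin>4</margin>\n        <spiculation>3</spiculation>\n        <texture>5</texture>\n        <calcification>1</calcification>\n        <internalStructure>1</internalStructure>\n        <lobulation>2</lobulation>\n        <subtlety>3</subtlety>\n      </characteristics>\n      <roi>\n        <imageZposition>-124.0</imageZposition>\n        <edgeMap>\n          <xCoord>256</xCoord>\n          <yCoord>215</yCoord>\n        </edgeMap>\n        <!-- more edge points... -->\n      </roi>\n      <!-- more ROIs... -->\n    </unblindedReadNodule>\n    <nonNodule>\n      <nonNoduleID>NonNodule001</nonNoduleID>\n      <imageZposition>-134.0</imageZposition>\n      <locus>\n        <xCoord>345</xCoord>\n        <yCoord>287</yCoord>\n      </locus>\n    </nonNodule>\n  </readingSession>\n  <!-- more readingSessions... -->\n</LidcReadMessage>\n```\n\n### Processed CSV Examples\n\n**Percent-coordinate CSV (output of process_lidc_annotations)**:\n\n```\npatient_id,anno_index,servicingRadiologistID,coord_x,coord_y,coord_z,diameter,malscore,sphericiy,margin,spiculation,texture,calcification,internal_structure,lobulation,subtlety\n1.3.6.1.4.1.14519.5.2.1.6279.6001.123456789,Nodule001,Reader1,0.5242,0.4455,0.3789,0.0521,4,5,4,3,5,1,1,2,3\n1.3.6.1.4.1.14519.5.2.1.6279.6001.123456789,Nodule001,Reader2,0.5256,0.4478,0.3802,0.0534,3,4,3,2,5,1,1,3,2\n```\n\n**Millimeter-coordinate CSV (output of percent_coordinatecsv_to_mmcsv)**:\n\n```\npatient_id,anno_index,servicingRadiologistID,coord_x,coord_y,coord_z,mm_x,mm_y,mm_z,diameter,malscore,sphericiy,margin,spiculation,texture,calcification,internal_structure,lobulation,subtlety\n1.3.6.1.4.1.14519.5.2.1.6279.6001.123456789,Nodule001,Reader1,0.5242,0.4455,0.3789,126.5,107.8,-124.0,0.0521,4,5,4,3,5,1,1,2,3\n1.3.6.1.4.1.14519.5.2.1.6279.6001.123456789,Nodule001,Reader2,0.5256,0.4478,0.3802,127.0,108.5,-124.2,0.0534,3,4,3,2,5,1,1,3,2\n```\n\n**CSV with average coordinates and malignancy label (final output; `real_mal` is inserted right after `avg_z`)**:\n\n```\npatient_id,anno_index,servicingRadiologistID,coord_x,coord_y,coord_z,mm_x,mm_y,mm_z,avg_x,avg_y,avg_z,real_mal,diameter,malscore,sphericiy,margin,spiculation,texture,calcification,internal_structure,lobulation,subtlety\n1.3.6.1.4.1.14519.5.2.1.6279.6001.123456789,Nodule001,Reader1,0.5242,0.4455,0.3789,126.5,107.8,-124.0,126.75,108.15,-124.1,1,0.0521,4,5,4,3,5,1,1,2,3\n1.3.6.1.4.1.14519.5.2.1.6279.6001.123456789,Nodule001,Reader2,0.5256,0.4478,0.3802,127.0,108.5,-124.2,126.75,108.15,-124.1,1,0.0534,3,4,3,2,5,1,1,3,2\n```\n\n### Nodule Annotation Statistics\n\nIn the LIDC-IDRI dataset:\n- there are 1018 cases in total\n- about 2669 nodules ≥3 mm were marked by at least one radiologist\n- about 928 nodules ≥3 mm were marked by all four radiologists\n- about 479 nodules are labeled malignant (mean malignancy rating >3)\n- about 591 nodules are labeled benign (mean malignancy rating <3)\n- about 858 nodules have indeterminate malignancy (mean malignancy rating =3)\n\n## Data Processing Pipeline\n\nData processing consists of the following main steps:\n\n1. Extract nodule information from the raw XML annotation files\n2. Convert the raw percent coordinates into millimeter coordinates\n3. Aggregate and merge the nodule annotations of multiple radiologists\n4. Compute the malignancy label of each nodule\n5. Generate the dataset used for model training\n\n## Scripts\n\n### 1. lidc_annotation_process.py\n\nThis script processes the XML annotation files of the LIDC-IDRI dataset and extracts the location, size and characteristics of each nodule.\n\n**Main functions**:\n\n- `read_nodule_annotation_from_xml()`: reads a single XML annotation file and extracts its nodule information\n- `process_lidc_annotations()`: processes all XML annotation files and aggregates all nodule information\n- `extract_lidc_every_z_annotations()`: extracts the annotation on every Z-axis slice\n- `merge_nodule_annotation_csv_to_one()`: merges several per-patient nodule annotation CSV files into one\n\n**Extracted information includes**:\n\n- nodule center coordinates (as a fraction of the image size)\n- nodule diameter\n- nodule malignancy rating (1-5)\n- other characteristics: sphericity, margin, spiculation, texture, calcification, internal structure, lobulation and subtlety\n\n### 2. lidc_coordinate_process.py\n\nThis script post-processes the results produced by `lidc_annotation_process.py`, focusing on the conversion and processing of annotation coordinates.\n\n**Main functions**:\n\n- `percent_coordinatecsv_to_mmcsv()`: converts percent coordinates to millimeter coordinates\n- `avg_coordinates()`: computes the average coordinates of annotations of the same nodule\n- `add_final_mals()`: computes the final malignancy label of each nodule\n- `draw_percent_cube_by_csv()`: draws cube regions from percent coordinates\n- `draw_all_confirmed_cubes()`: draws the cubes of all confirmed nodules\n\n## Processing Steps in Detail\n\n1. **Read the raw XML annotations**:\n   - parse the XML files to extract the annotations of each radiologist\n   - obtain the nodule center coordinates (as a fraction of the image size)\n   - extract the nodule characteristics (malignancy rating, texture, etc.)\n\n2. **Coordinate conversion**:\n   - convert percent coordinates to millimeter coordinates\n   - compute the actual physical position from the image metadata\n\n3. **Nodule aggregation**:\n   - the same nodule may be annotated by several radiologists\n   - average the coordinates of annotations that lie within a distance threshold of each other\n   - merge the annotations of different radiologists for the same nodule\n\n4. **Malignancy computation**:\n   - each nodule is rated by several radiologists (1-5)\n   - the final label is decided by a majority vote between benign (<3) and malignant (>3) ratings\n   - 0 = benign, 1 = malignant, \"unknow\" = indeterminate\n\n5. **Data visualization**:\n   - draw the nodule cubes on the original CT images\n   - used to verify the accuracy of the annotations\n\n## Usage Example\n\n```python\nfrom data.preprocessing.lidc_process.lidc_annotation_process import process_lidc_annotations\nfrom data.preprocessing.lidc_process.lidc_coordinate_process import (\n    percent_coordinatecsv_to_mmcsv, avg_coordinates, add_final_mals)\n\n# patient_mhd_path_dict maps patient id -> mhd file path;\n# mhd_info_csv is the csv with the shape/origin/spacing of every scan\n\n# process the XML annotations\nprocess_lidc_annotations(\"path/to/xml/*.xml\", patient_mhd_path_dict, \"all_annotations.csv\")\n\n# convert coordinates\npercent_coordinatecsv_to_mmcsv(\"all_annotations.csv\", mhd_info_csv, \"mm_coordinates.csv\")\n\n# compute average coordinates (annotations closer than 5 mm count as the same nodule)\navg_coordinates(\"mm_coordinates.csv\", 5, \"avg_coordinates.csv\")\n\n# add the final malignancy label\nadd_final_mals(\"avg_coordinates.csv\", \"final_annotations.csv\")\n```\n\n## Dataset Reference\n\nLIDC-IDRI dataset: [https://www.cancerimagingarchive.net/collection/lidc-idri/](https://www.cancerimagingarchive.net/collection/lidc-idri/) "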
  },
  {
    "path": "data/preprocessing/lidc_process/__init__.py",
    "content": ""
  },
  {
    "path": "data/preprocessing/lidc_process/lidc_annotation_process.py",
    "content": "import pandas as pd\nimport math\nimport glob\nimport os\n\nfrom bs4 import BeautifulSoup\nfrom constant import luna\n\n# path of the csv file with the mhd (or DICOM) information of every scan\nmhd_info_csv = luna.MHD_INFO_CSV\n# csv column names of positive nodules\npos_annotation_csv_head = [\"anno_index\", \"coord_x\", \"coord_y\", \"coord_z\", \"diameter\", \"malscore\"]\n# csv column names of negative nodules (non-nodules)\nneg_annotation_csv_head = [\"anno_index\", \"coord_x\", \"coord_y\", \"coord_z\", \"diameter\", \"malscore\"]\n# root path of every patient's annotation information extracted from the xml files\nextracted_annotation_info_root_path ='/data/LUNA2016/extracted_annotation_infos/'\n# column names of all information extracted from the xml annotations\nall_annotation_csv_head = [\"patient_id\", \"anno_index\",\"servicingRadiologistID\", \"coord_x\", \"coord_y\", \"coord_z\", \"diameter\",\n                           \"malscore\",\"sphericiy\", \"margin\", \"spiculation\", \"texture\", \"calcification\", \"internal_structure\", \"lobulation\", \"subtlety\"]\n\n\ndef merge_nodule_annotation_csv_to_one(nodule_annotation_csv_list,save_file):\n    '''\n        merge several per-patient nodule annotation csv files into one csv file,\n        prefixing every row with the patient id taken from the file name\n\n    :param nodule_annotation_csv_list:  list of per-patient csv files, named like <patient_id>_annos_pos.csv\n    :param save_file:                   path of the merged csv file\n    :return:\n    '''\n    annotation_list = []\n    # add patient id to annotation information csv file\n    annotation_head = 'patient_id,'+','.join(pos_annotation_csv_head)\n    for csv in nodule_annotation_csv_list:  # csv filename like : 1.3.6.1.4.1.14519.5.2.1.6279.6001.106630482085576298661469304872_annos_pos.csv\n        with open(csv,'r') as read_csv:\n            contents = read_csv.readlines()\n            patient_id = os.path.basename(csv).split(\"_\")[0]  # get patient id\n            for line in contents[1:]:                         # skip column head\n                annotation_list.append(patient_id+\",\"+line)\n    with open(save_file,'w') as save_csv:\n        save_csv.write(annotation_head+\"\\r\\n\")\n        for line in annotation_list:\n            save_csv.write(line)\n    print(\"csv annotation file merged into file \",save_file)\n\ndef read_nodule_annotation_from_xml(xml_path,patient_mhd_path_dict,agreement_threshold=0):\n    '''\n         read the nodule annotation information from a single xml file\n\n    :param xml_path:                single xml annotation file\n    :param patient_mhd_path_dict:   dict mapping patient id (SeriesInstanceUid) to its mhd file path; xml files of patients missing from it are skipped\n    :param agreement_threshold:     every CT scan was marked by several (at most 4) radiologists; when > 1, a nodule is kept only if it overlaps with at least this many other marks\n    :return:                        (pos_lines, neg_lines, extended_lines), or (None, None, None) if the xml file cannot be used\n    '''\n    pos_lines = []\n    neg_lines = []\n    extended_lines = []\n    with open(xml_path, 'r') as xml_file:\n        markup = xml_file.read()\n    xml = BeautifulSoup(markup, features=\"xml\")\n    if xml.LidcReadMessage is None:\n        return None, None, None\n    patient_id = xml.LidcReadMessage.ResponseHeader.SeriesInstanceUid.text\n    src_path = None\n    if patient_id in patient_mhd_path_dict:\n        src_path = patient_mhd_path_dict[patient_id]\n    if src_path is None:\n        return None, None, None\n\n    print(patient_id)\n    mhd_info_pd = pd.read_csv(mhd_info_csv)\n    mhd_info_row = mhd_info_pd[mhd_info_pd['patient_id']==patient_id]\n    print(\"information about this patient is:\\t\",mhd_info_row)\n    num_z, height, width = list(mhd_info_row['shape_2'])[0],list(mhd_info_row['shape_1'])[0],list(mhd_info_row['shape_0'])[0]\n    print(\"num_z,height,width are:\\t\",num_z, height, width)\n    origin_x,origin_y,origin_z = list(mhd_info_row['origin_x'])[0],list(mhd_info_row['origin_y'])[0],list(mhd_info_row['origin_z'])[0]\n    spacing_x,spacing_y,spacing_z = list(mhd_info_row['spacing_x'])[0],list(mhd_info_row['spacing_y'])[0],list(mhd_info_row['spacing_z'])[0]\n\n    # a reading session consists of a set of markings done by a single\n    # reader at a single phase (for these xml files, the unblinded reading phase).\n    reading_sessions = xml.LidcReadMessage.find_all(\"readingSession\")\n\n    for reading_session in reading_sessions:\n        servicingRadiologistID = reading_session.servicingRadiologistID.text\n        nodules = reading_session.find_all(\"unblindedReadNodule\")\n        for nodule in nodules:\n            nodule_id = nodule.noduleID.text\n            rois = nodule.find_all(\"roi\")\n            x_min = y_min = z_min = 999999\n            x_max = y_max = z_max = -999999\n            if len(rois) < 2:\n                continue\n\n            for roi in rois:\n                z_pos = float(roi.imageZposition.text)\n                z_min = min(z_min, z_pos)\n                z_max = max(z_max, z_pos)\n                edge_maps = roi.find_all(\"edgeMap\")\n                for edge_map in edge_maps:\n                    x = int(edge_map.xCoord.text)\n                    y = int(edge_map.yCoord.text)\n                    x_min = min(x_min, x)\n                    y_min = min(y_min, y)\n                    x_max = max(x_max, x)\n                    y_max = max(y_max, y)\n                if x_max == x_min:\n                    continue\n                if y_max == y_min:\n                    continue\n\n            x_diameter = x_max - x_min\n            x_center = x_min + x_diameter / 2\n            y_diameter = y_max - y_min\n            y_center = y_min + y_diameter / 2\n            z_diameter = z_max - z_min\n            z_center = z_min + z_diameter / 2\n            z_center -= origin_z\n            z_center /= spacing_z\n\n            x_center_perc = round(x_center / width, 4)\n            y_center_perc = round(y_center / height, 4)\n            z_center_perc = round(z_center / num_z, 4)\n            diameter = max(x_diameter , y_diameter)\n            diameter_perc = round(max(x_diameter / width, y_diameter / height), 4)\n\n            if nodule.characteristics is None:\n                print(\"!!!!Nodule:\", nodule_id, \" has no characteristics\")\n                continue\n            if nodule.characteristics.malignancy is None:\n                print(\"!!!!Nodule:\", nodule_id, \" has no malignancy\")\n                continue\n            print(\"nodule in load xml\",x_center_perc,y_center_perc,z_center_perc)\n            malignancy = nodule.characteristics.malignancy.text\n            sphericiy = nodule.characteristics.sphericity.text\n            margin = nodule.characteristics.margin.text\n            spiculation = nodule.characteristics.spiculation.text\n            texture = nodule.characteristics.texture.text\n            calcification = nodule.characteristics.calcification.text\n            internal_structure = nodule.characteristics.internalStructure.text\n            lobulation = nodule.characteristics.lobulation.text\n            subtlety = nodule.characteristics.subtlety.text\n\n            line = [nodule_id, x_center_perc, y_center_perc, z_center_perc, diameter_perc, malignancy]\n            extended_line = [patient_id, nodule_id, servicingRadiologistID,x_center_perc, y_center_perc, z_center_perc,\n                             diameter_perc, malignancy,sphericiy, margin, spiculation, texture, calcification, internal_structure, lobulation, subtlety]\n\n            pos_lines.append(line)\n            extended_lines.append(extended_line)\n\n        nonNodules = reading_session.find_all(\"nonNodule\")\n        for nonNodule in nonNodules:\n            z_center = float(nonNodule.imageZposition.text)\n            z_center -= origin_z\n            z_center /= spacing_z\n            x_center = int(nonNodule.locus.xCoord.text)\n            y_center = int(nonNodule.locus.yCoord.text)\n            nodule_id = nonNodule.nonNoduleID.text\n            x_center_perc = round(x_center / width, 4)\n            y_center_perc = round(y_center / height, 4)\n            z_center_perc = round(z_center / num_z, 4)\n            # non-nodules have no contour, so a fixed diameter of 6 pixels is used\n            diameter_perc = round(max(6 / width, 6 / height), 4)\n            line = [nodule_id, x_center_perc, y_center_perc, z_center_perc, diameter_perc, 0]\n            neg_lines.append(line)\n\n    if agreement_threshold > 1:\n        filtered_lines = []\n        for pos_line1 in pos_lines:\n            id1 = pos_line1[0]\n            x1 = pos_line1[1]\n            y1 = pos_line1[2]\n            z1 = pos_line1[3]\n            d1 = pos_line1[4]\n            overlaps = 0\n            for pos_line2 in pos_lines:\n                id2 = pos_line2[0]\n                if id1 == id2:\n                    continue\n                x2 = pos_line2[1]\n                y2 = pos_line2[2]\n                z2 = pos_line2[3]\n                d2 = pos_line2[4]\n                dist = math.sqrt(math.pow(x1 - x2, 2) + math.pow(y1 - y2, 2) + math.pow(z1 - z2, 2))\n                if dist < d1 or dist < d2:\n                    overlaps += 1\n            if overlaps >= agreement_threshold:\n                filtered_lines.append(pos_line1)\n        pos_lines = filtered_lines\n\n    df_annos = pd.DataFrame(pos_lines, columns=pos_annotation_csv_head)\n    df_annos.to_csv(extracted_annotation_info_root_path + patient_id + \"_annos_pos_lidc.csv\", index=False)\n    df_neg_annos = pd.DataFrame(neg_lines, columns=neg_annotation_csv_head)\n    df_neg_annos.to_csv(extracted_annotation_info_root_path + patient_id + \"_annos_neg_lidc.csv\", index=False)\n\n    return pos_lines, neg_lines, extended_lines\n\n\n\ndef process_lidc_annotations(xml_annotation_like,patient_mhd_path_dict,mhd_all_info_save_path,agreement_threshold=0):\n    '''\n        extract the nodule annotation information of all matching xml files and save it into one csv file\n\n    :param xml_annotation_like:             glob pattern matching the xml annotation files\n    :param patient_mhd_path_dict:           key is patient id, value is the mhd file path\n    :param mhd_all_info_save_path:          where to save the csv file with all extracted annotation information (not mhd_info_csv)\n    :param agreement_threshold:             minimum number of overlapping marks by other readers required to keep a nodule (see read_nodule_annotation_from_xml)\n    :return:\n    '''\n    file_no = 0\n    pos_count = 0\n    neg_count = 0\n    all_lines = []\n\n    xml_paths = glob.glob(xml_annotation_like)\n    for xml_path in xml_paths:\n        pos, neg, extended = read_nodule_annotation_from_xml(xml_path, patient_mhd_path_dict,agreement_threshold=agreement_threshold)\n        if pos is not None:\n            pos_count += len(pos)\n            neg_count += len(neg)\n            file_no += 1\n            all_lines += extended\n    df_annos = pd.DataFrame(all_lines, columns= all_annotation_csv_head)\n    df_annos.to_csv(mhd_all_info_save_path, index=False)\n    print(\"processed %d xml files: %d positive and %d negative annotations\" % (file_no, pos_count, neg_count))\n\ndef extract_lidc_every_z_annotations(xml_like,every_z_save_csv,patient_mhd_path_dict):\n    '''\n        extract the annotation of every z slice (ROI) from all matching xml files into one csv file\n\n    :param xml_like:                glob pattern matching the xml annotation files\n    :param every_z_save_csv:        csv file to save the per-slice annotations\n    :param patient_mhd_path_dict:   key is patient id, value is its full mhd file path\n    '''\n    xml_paths = glob.glob(xml_like)\n    with open(every_z_save_csv,\"w\") as anno_save:\n        # note: mm_x, mm_y and diameter_mm actually hold pixel values taken from the edge maps\n        anno_save.write(\"seriesuid,coord_percent_x, coord_percent_y,coord_mm_z,percent_diamater,mm_x,mm_y,diameter_mm,malscore\")\n        anno_save.write(\"\\r\\n\")\n        for xml_path in xml_paths:\n            extended = extract_every_z_from_lidc_xml(xml_path, patient_mhd_path_dict)\n            if extended is not None:\n                # write one csv row per ROI\n                for line in extended:\n                    anno_save.write(line + \"\\r\\n\")\n\n\ndef extract_every_z_from_lidc_xml(xml_path,patient_mhd_path_dict):\n    '''\n        extract every nodule ROI (one per z slice) from an xml file; the method `read_nodule_annotation_from_xml` above\n        only extracts the z center of each nodule\n\n        this method is used to produce more nodule masks for the UNet\n\n    :param xml_path:                xml file of nodule annotation\n    :param patient_mhd_path_dict:   key is patient id, value is its full mhd file path\n    :return:                        list of csv rows (one per ROI), or None if the xml file cannot be used\n    '''\n    extended_lines = []\n    with open(xml_path, 'r') as xml_file:\n        markup = xml_file.read()\n    xml = BeautifulSoup(markup, features=\"xml\")\n    if xml.LidcReadMessage is None:\n        return None\n    patient_id = xml.LidcReadMessage.ResponseHeader.SeriesInstanceUid.text\n\n    print(\"patient id is:\\t\", patient_id)\n    if patient_id in patient_mhd_path_dict:\n        src_path = patient_mhd_path_dict[patient_id]\n    else:\n        return None\n\n    mhd_info_pd = pd.read_csv(mhd_info_csv)\n    mhd_info_row = mhd_info_pd[mhd_info_pd['patient_id'] == patient_id]\n    print(\"information about this patient is:\\t\", mhd_info_row)\n    num_z, height, width = list(mhd_info_row['shape_2'])[0], list(mhd_info_row['shape_1'])[0], \\\n                           list(mhd_info_row['shape_0'])[0]\n    print(\"num_z,height,width are:\\t\", num_z, height, width)\n    origin_x, origin_y, origin_z = list(mhd_info_row['origin_x'])[0], list(mhd_info_row['origin_y'])[0], \\\n                                   list(mhd_info_row['origin_z'])[0]\n    spacing_x, spacing_y, spacing_z = list(mhd_info_row['spacing_x'])[0], list(mhd_info_row['spacing_y'])[0], \\\n                                      list(mhd_info_row['spacing_z'])[0]\n\n    # a reading session consists of a set of markings done by a single\n    # reader at a single phase (for these xml files, the unblinded reading phase).\n    reading_sessions = xml.LidcReadMessage.find_all(\"readingSession\")\n\n    for reading_session in reading_sessions:\n        servicingRadiologistID = reading_session.servicingRadiologistID.text\n        nodules = reading_session.find_all(\"unblindedReadNodule\")\n        for nodule in nodules:\n            nodule_id = nodule.noduleID.text\n            if nodule.characteristics is None:\n                print(\"!!!!Nodule:\", nodule_id, \" has no characteristics\")\n                continue\n            if nodule.characteristics.malignancy is None:\n                print(\"!!!!Nodule:\", nodule_id, \" has no malignancy\")\n                continue\n            malignancy = nodule.characteristics.malignancy.text\n\n            rois = nodule.find_all(\"roi\")\n            if len(rois) < 2:\n                continue\n\n            for roi in rois:\n                z_pos = float(roi.imageZposition.text)\n                # bounding box of this ROI only\n                x_min = y_min = 999999\n                x_max = y_max = -999999\n                edge_maps = roi.find_all(\"edgeMap\")\n                for edge_map in edge_maps:\n                    x = int(edge_map.xCoord.text)\n                    y = int(edge_map.yCoord.text)\n                    x_min = min(x_min, x)\n                    y_min = min(y_min, y)\n                    x_max = max(x_max, x)\n                    y_max = max(y_max, y)\n                # skip degenerate ROIs (a single point or a straight line)\n                if x_max == x_min:\n                    continue\n                if y_max == y_min:\n                    continue\n\n                x_diameter = x_max - x_min\n                x_center = x_min + x_diameter / 2\n                y_diameter = y_max - y_min\n                y_center = y_min + y_diameter / 2\n\n                x_center_perc = round(x_center / width, 4)\n                y_center_perc = round(y_center / height, 4)\n                # in pixels, despite the name\n                diameter_mm = max(x_diameter, y_diameter)\n                diameter_perc = round(max(x_diameter / width, y_diameter / height), 4)\n\n                extended_line = \",\".join([patient_id, str(x_center_perc), str(y_center_perc), str(z_pos),\n                                          str(diameter_perc), str(x_center), str(y_center), str(diameter_mm), malignancy])\n                extended_lines.append(extended_line)\n\n    return extended_lines\n\n\n"
  },
  {
    "path": "data/preprocessing/lidc_process/lidc_coordinate_process.py",
    "content": "# -*- coding:utf-8 -*-\n'''\n a script to process result produced by lidc_annotation_process.py\n\n mainly focus on process mm coordinates of annotations\n'''\n\nimport math\nimport os\nimport pandas as pd\nfrom constant import luna\nfrom util import mhd_util,image_util, cube\n\n\ndef draw_percent_cube_by_csv(percent_csv,mhd_info_csv,cube_save_path):\n    '''\n       coordinate from xml are percent of image shape,\n    :param percent_coordinate_csv:\n    :return:\n    '''\n    mhd_pandas_index = mhd_util.read_csv_to_pandas(mhd_info_csv,',')\n    percent_pandas = mhd_util.read_csv_to_pandas(percent_csv,',')\n\n    for index, row in percent_pandas.iterrows():\n        patient_id = row['patient_id']\n        coord_x = float(row['coord_x'])\n        coord_y = float(row['coord_y'])\n        coord_z = float(row['coord_z'])\n        malscore = row['malscore']\n        diameter = float(row['diameter'])\n        print(\"read cube coordx,coordy,coordz,diamater,malscore\",coord_x,coord_y,coord_z,diameter,malscore)\n        png_path = luna.LUNA_EXTRACTED_IMG + '/' + patient_id\n        if os.path.exists(png_path):\n            cube.draw_percent_cube(png_path, mhd_pandas_index, coord_x, coord_y,\n                                   coord_z, diameter, cube_save_path, probility =malscore)\n        else:\n            print(\"one patient not exists..\", png_path)\n\n\ndef percent_coordinatecsv_to_mmcsv(percent_csv,mhd_info_csv,mmcsv_save):\n    '''\n       transform percent coordinates to mm coordinates and save into new csv file\n\n    :param percent_csv:     csv file contains all percent coordinates\n    :param mmcsv_save:      csv file to save transformed coordinates\n    :return:\n    '''\n    with open(percent_csv,'r') as percent_read:\n        head = percent_read.readline()\n        print(str(head))\n        head = str(head).replace(\"coord_x,coord_y,coord_z\",\"coord_x,coord_y,coord_z,mm_x,mm_y,mm_z\")\n        extend_mm_coordinate_content = []\n        lines 
= percent_read.readlines()\n        for line in lines:\n            cols = line.split(\",\")\n            patient_id = cols[0]\n            p_x,p_y,p_z = cols[3],cols[4],cols[5] #row['coord_x'],row['coord_y'],row['coord_z']\n            print(\"one line\\t\",line)\n            print(\"patient id = \",patient_id)\n            print(\"coordinates is:\\t \",p_x,p_y,p_z)\n            mm_x,mm_y,mm_z = percent_coordinate_to_mm(patient_id,float(p_x),float(p_y),float(p_z),mhd_info_csv)\n            print(\"after transfered ..:\\t\",mm_x,mm_y,mm_z)\n            line = line.replace(str(p_x)+\",\"+str(p_y)+\",\"+str(p_z),\n                                str(p_x)+\",\"+str(p_y)+\",\"+str(p_z)+\",\"+str(mm_x)+\",\"+str(mm_y)+\",\"+str(mm_z))\n            print(line)\n            extend_mm_coordinate_content.append(line)\n\n        with open(mmcsv_save,\"w\") as mm:\n            mm.write(head)\n            for line in extend_mm_coordinate_content:\n                mm.write(str(line))\n    print(\"transformed mm coordinates finished..\")\n\n\ndef avg_coordinates(csv,threshold,csv_save):\n    '''\n        add average coordinates to source.\n     before:\n        patient_id0, coord_x0,coord_y0,coord_z0\n        patient_id0, coord_x1,coord_y1,coord_z1\n     after:\n        patient_id0, coord_x0,coord_y0,coord_z0,avg_x,avg_y,avg_z\n        patient_id0, coord_x1,coord_y1,coord_z1,avg_x,avg_y,avg_z\n    :param csv:\n    :param csv_save:        csv file to save transformed content\n    :return:\n    '''\n    new_content = []\n    patient_coords = {}\n    with open(csv,'r') as csv_read:\n        head = csv_read.readline()\n        lines = csv_read.readlines()\n        for line in lines:\n            cols = line.split(\",\")\n            patient_id = cols[0]\n            mm_x,mm_y,mm_z,diam,mals = cols[6],cols[7],cols[8],cols[9],cols[10]\n            if patient_id in patient_coords:\n                patient_coords[patient_id].append([mm_x,mm_y,mm_z,diam,mals])\n            else:\n        
        patient_coords[patient_id]= [[mm_x, mm_y, mm_z, diam,mals]]\n\n    # get average coords\n    avg_same_coords = {}\n\n    for key,value in patient_coords.items():\n        patient_id = key\n        coords = value\n        xyzs = []\n        for coor in coords: # list of lists\n            x,y,z = coor[0],coor[1],coor[2]\n            xyzs.append([float(x),float(y),float(z)])\n\n        # find every coordinate's neighbor coordinates (distance smaller than threshold)\n        same_coords = {}\n        coord_num = len(xyzs)\n        i = 0\n        while i <coord_num:\n            current_x,current_y,current_z =xyzs[i][0],xyzs[i][1],xyzs[i][2]\n            same_with_current = str(current_x)+\",\"+str(current_y)+\",\"+str(current_z)\n            same_coords[same_with_current] = [[current_x,current_y,current_z]]   # put itself into its neighbor\n            j = 0\n            while j < coord_num:\n                x,y,z = xyzs[j][0],xyzs[j][1],xyzs[j][2]\n                dis = math.sqrt((x-current_x)**2+(y-current_y)**2+(z-current_z)**2)            # distance of two coordinates\n                if dis< threshold and [x,y,z] not in same_coords[same_with_current]:\n                    same_coords[same_with_current].append([x,y,z])\n                j = j+1\n            i+=1\n\n        # get average of coordinates\n        for key,value in same_coords.items():\n            cur_x,curr_y,curr_z = key.split(\",\")\n            x,y,z = 0,0,0\n            #print(\"key:  value: \",key,\":\\t\",value)\n            if len(value)>0:\n                for same_cor in value:\n                    #print(same_cor)\n                    x = x + same_cor[0]\n                    y = y + same_cor[1]\n                    z = z + same_cor[2]\n                x = round(x /len(value),2)\n                y = round(y /len(value),2)\n                z = round(z /len(value),2)\n                # update dict with average\n            avg_same_coords[key] = [x,y,z]\n\n    with open(csv,'r') as 
csv_read:\n        head = csv_read.readline()\n        head = head.replace(\"mm_x,mm_y,mm_z\",\"mm_x,mm_y,mm_z,avg_x,avg_y,avg_z\")\n        new_content.append(head)\n        lines = csv_read.readlines()\n        for line in lines:\n            cols = line.split(\",\")\n            patient_id = cols[0]\n            mm_x,mm_y,mm_z,diam,mals = cols[6],cols[7],cols[8],cols[9],cols[10]\n            avg_xyz = avg_same_coords[mm_x+\",\"+mm_y+\",\"+mm_z]\n            avg_x,avg_y,avg_z = str(avg_xyz[0]),str(avg_xyz[1]),str(avg_xyz[2])\n            new_content.append(line.replace(mm_x+\",\"+mm_y+\",\"+mm_z,\n                                           mm_x + \",\" + mm_y + \",\" + mm_z+\",\"+avg_x+\",\"+avg_y+\",\"+avg_z))\n\n    with open(csv_save,'w') as info:\n        for line in new_content:\n            print(\"line from lidc_coordinate:\\t\",line)\n            info.write(line)\n\n    print(\"write attachement information to %s finished..\"%csv_save)\n\n\ndef add_final_mals(csv,with_real_malsclabel_csv):\n    \"\"\"\n            compute real malignancy of every patient. 
every nodule was labeled by several different readers,\n        this step is comfirming a final malignancy label\n    :param csv:                         csv file of all mhd information\n    :param with_real_malsclabel_csv:    csv file after add real malscore columns\n    :return:\n    \"\"\"\n    nodule_mals = {}\n    with open(csv,'r') as read_csv:\n        head = read_csv.readline()\n        lines = read_csv.readlines()\n        for line in lines:\n            cols = line.split(\",\")\n            patient_id = cols[0]\n            avg_x,avg_y,avg_z,mals = cols[9],cols[10],cols[11],cols[13]\n            key = patient_id+\",\"+avg_x+\",\"+avg_y+\",\"+avg_z\n            if key not in nodule_mals:\n                nodule_mals[key] = [int(mals)]\n            else:\n                nodule_mals[key].append(int(mals))\n\n    # compute the real malignancy label\n    for key,val in nodule_mals.items():\n        mals = val\n        print(\"patient_id and all malscore is:\\t\",key+\"\\t:\",val)\n        non_cancer = 0\n        unknow = 0\n        cancer = 0\n        UNK =\"unknow\"\n        for mal in mals:\n            if mal<3:\n                non_cancer +=1\n            elif mal ==3:\n                unknow +=1\n            elif mal>3:\n                cancer+=1\n        real_mal = \"\"\n        if unknow == len(mals):         # all are unknow\n            real_mal = UNK\n        elif non_cancer/(non_cancer+cancer)>0.5:\n            real_mal = \"0\"\n        elif non_cancer/(non_cancer+cancer)==0.5:\n            real_mal =UNK\n        elif non_cancer/(non_cancer+cancer)<0.5:\n            real_mal = \"1\"\n\n        if real_mal==\"0\" and unknow>non_cancer:\n            real_mal = UNK\n\n        print(\"non_cancer,unk,cancer,real label\", non_cancer, unknow, cancer,real_mal)\n        # update the mal label\n        nodule_mals[key] = real_mal\n\n    # add real mal columns into csv file\n    with_real_mals_content = []\n    with open(csv,'r') as read_csv:\n        head = 
read_csv.readline()\n        head = head.replace(\"avg_x,avg_y,avg_z\",\"avg_x,avg_y,avg_z,real_mal\")\n        with_real_mals_content.append(head)\n        lines = read_csv.readlines()\n        for line in lines:\n            cols = line.split(\",\")\n            patient_id = cols[0]\n            avg_x,avg_y,avg_z,mals = cols[9],cols[10],cols[11],cols[13]\n            key = patient_id+\",\"+avg_x+\",\"+avg_y+\",\"+avg_z\n            real_mal = nodule_mals[key]\n            # average coordinates equal to source coordinates\n            if avg_x + \",\" + avg_y +\",\" + avg_z+\",\"+avg_x + \",\" + avg_y +\",\" + avg_z in line:\n                with_real_mals_content.append(line.replace(avg_x + \",\" + avg_y +\",\" + avg_z + \",\" + avg_x + \",\" + avg_y +\",\" + avg_z,\n                                                           avg_x + \",\" + avg_y + \",\" + avg_z +\",\"+ avg_x + \",\" + avg_y + \",\" + avg_z +\",\"+ real_mal))\n            else:\n                with_real_mals_content.append(line.replace(avg_x + \",\" + avg_y +\",\" + avg_z,\n                                                       avg_x + \",\" + avg_y + \",\" + avg_z + \",\"+ real_mal))\n\n    # write the final result with real malscore columns into file\n    with open(with_real_malsclabel_csv) as with_mal:\n        for line in with_real_mals_content:\n            with_mal.write(line)\n\n    print(\"write result with real malscore columns finished..\")\n\n\n\ndef percent_coordinate_to_mm(patient_id,p_x,p_y,p_z,mhd_info_csv):\n    \"\"\"\n      transform percent coordinate to mm coordinates\n\n    :param patient_id:       patient id,used for mapping information from mhd_info_csv\n    :param p_x:              x percent coordinate\n    :param p_y:\n    :param p_z:\n    :param mhd_info_csv:    a csv file contains  all mhd information ,such as shape,spacing,origion\n    :return:                transformed mm coordinate x,y,z\n    \"\"\"\n    png_path = png_path = luna.LUNA_EXTRACTED_IMG + 
'/'+patient_id\n    patient_img = image_util.load_patient_images(png_path, \"*_i.png\", [])\n    mhd_pandas_index = mhd_util.read_csv_to_pandas(mhd_info_csv,',')\n    patient_mhd_info = mhd_pandas_index.loc[patient_id]\n\n    z = int(p_z * patient_img.shape[0])\n    y = int(p_y * patient_img.shape[1])\n    x = int(p_x * patient_img.shape[2])\n    origin_x = float(patient_mhd_info['origin_x'].strip())\n    origin_y = float(patient_mhd_info['origin_y'].strip())\n    origin_z = float(patient_mhd_info['origin_z'].strip())\n\n    # adding a voxel index to the origin yields mm only for 1 mm isotropic (resampled) images\n    right_x = x + origin_x\n    right_y = y + origin_y\n    right_z = z + origin_z\n\n    return round(right_x,2),round(right_y,2),round(right_z,2)\n\ndef draw_all_confirmed_cubes(mm_coordinates_csv,mhd_info_csv,extract_png_path,save_path):\n    \"\"\"\n        draw all nodules annotated in the official LUNA16 annotations\n    :param mm_coordinates_csv:\n    :param mhd_info_csv:\n    :param extract_png_path:\n    :param save_path:\n    :return:\n    \"\"\"\n    coordinates = pd.read_csv(mm_coordinates_csv)\n    count = 0\n    mhd_info = mhd_util.read_csv_to_pandas(mhd_info_csv)\n    for df_index, df_row in coordinates.iterrows():\n        patient_id = df_row['seriesuid']\n        patient_png_path = os.path.join(extract_png_path,patient_id)\n        mm_x = df_row['coordX']\n        mm_y = df_row['coordY']\n        mm_z = df_row['coordZ']\n        diameter = df_row['diameter_mm']\n        if os.path.exists(patient_png_path):\n            cube.draw_percent_cube(patient_png_path, mm_x, mm_y, mm_z, diameter, save_path, mhd_pandas_index =mhd_info)\n        else:\n            count +=1\n    print(\"draw all cubes finished...%d cubes missed\"%count)\n\n"
  },
  {
    "path": "data/preprocessing/luna16_invalid_nodule_filter.py",
    "content": "## Filter out problematic annotations in the LUNA16 candidate nodules, as well as false nodules produced during prediction\nimport numpy as np\ndef nodule_valid(ct_data, voxel_coord_x, voxel_coord_y,voxel_coord_z):\n    \"\"\"\n        Decide whether the current nodule can be used as a training cube, or whether a cube found by scanning is a plausible nodule\n    :param ct_data:         CTData instance whose pixels are already converted to 0-255 and restricted to the lung region\n    :param voxel_coord_x:   voxel coordinate of the center of the cube to check\n    :param voxel_coord_y:\n    :param voxel_coord_z:\n    :return:                whether the current nodule is usable: True (usable) / False (not usable)\n    \"\"\"\n    lung_mask = ct_data.lung_seg_mask\n    # Check that the coordinate lies inside the volume\n    if (voxel_coord_z < 0 or voxel_coord_z >= lung_mask.shape[0] or\n            voxel_coord_y < 0 or voxel_coord_y >= lung_mask.shape[1] or\n            voxel_coord_x < 0 or voxel_coord_x >= lung_mask.shape[2]):\n        return False\n    # Take the neighborhood within a radius of 5 voxels\n    z_min = max(0, voxel_coord_z - 5)\n    z_max = min(lung_mask.shape[0], voxel_coord_z + 6)\n    y_min = max(0, voxel_coord_y - 5)\n    y_max = min(lung_mask.shape[1], voxel_coord_y + 6)\n    x_min = max(0, voxel_coord_x - 5)\n    x_max = min(lung_mask.shape[2], voxel_coord_x + 6)\n    # Extract the lung mask of the neighborhood\n    neighborhood_mask = lung_mask[z_min:z_max, y_min:y_max, x_min:x_max]\n    # Fraction of lung tissue in the neighborhood\n    lung_ratio = np.mean(neighborhood_mask)\n    # If too little of the neighborhood is lung tissue, this is likely a false positive\n    if lung_ratio < 0.5:\n        return False\n\n    # Check whether the point lies on the lung border\n    # (only for interior points, so that all 6 neighbors exist)\n    if (0 < voxel_coord_z < lung_mask.shape[0] - 1 and\n            0 < voxel_coord_y < lung_mask.shape[1] - 1 and\n            0 < voxel_coord_x < lung_mask.shape[2] - 1):\n        # Count lung voxels in the 6-neighborhood (up/down, left/right, front/back)\n        neighbors = [\n            lung_mask[voxel_coord_z - 1, voxel_coord_y, voxel_coord_x],\n            lung_mask[voxel_coord_z + 1, voxel_coord_y, voxel_coord_x],\n            lung_mask[voxel_coord_z, voxel_coord_y - 1, voxel_coord_x],\n            lung_mask[voxel_coord_z, voxel_coord_y + 1, voxel_coord_x],\n            lung_mask[voxel_coord_z, voxel_coord_y, voxel_coord_x - 1],\n            lung_mask[voxel_coord_z, voxel_coord_y, voxel_coord_x + 1]\n   
     ]\n        # Too many non-lung voxels among the neighbors suggests the point is on the lung border\n        if sum(neighbors) < 4:\n            return False\n    return True"
  },
  {
    "path": "data/preprocessing/luna16_prepare_cube_data.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\nimport os\nimport pandas as pd\nfrom tqdm import tqdm\nimport multiprocessing as mp\nimport time\nfrom data.dataclass.CTData import CTData\nfrom data.dataclass.NoduleCube import NoduleCube\nfrom data.preprocessing.luna16_invalid_nodule_filter import nodule_valid\n\ndef get_mhd_file_path(patient_id, luna16_root=\"H:/luna16\"):\n    \"\"\"\n    Find the MHD file path for a given patient_id\n\n    Args:\n        patient_id: patient ID in the LUNA16 dataset\n        luna16_root: root directory of the LUNA16 dataset\n\n    Returns:\n        Full path of the MHD file, or None if it is not found\n    \"\"\"\n    # Subset folders of the LUNA16 dataset\n    subsets = [f\"subset{i}\" for i in range(10)]\n    # Search every subset for the matching MHD file\n    for subset in subsets:\n        subset_path = os.path.join(luna16_root, subset)\n        if os.path.exists(subset_path):\n            mhd_file = os.path.join(subset_path, f\"{patient_id}.mhd\")\n            if os.path.exists(mhd_file):\n                return mhd_file\n\n    # No matching MHD file was found\n    print(f\"Warning: MHD file for patient {patient_id} not found\")\n    return None\n\ndef get_real_candidate(mhd_root_dir, annotation_csv, candidate_csv,save_real_candidate_csv):\n    \"\"\"\n        The official candidate annotations contain problematic entries: some candidates lie close to real nodules, which mislabels data. Remove them according to the rule below\n    :param mhd_root_dir:\n    :param annotation_csv:\n    :param candidate_csv:\n    :param save_real_candidate_csv:\n    :return:\n    \"\"\"\n    positive_df = pd.read_csv(annotation_csv)\n    # copy() so that adding the class column does not modify a view of positive_df\n    part_positive_df = positive_df[[\"seriesuid\",\"coordX\", \"coordY\",\"coordZ\"]].copy()\n    part_positive_df[\"class\"] = 1\n    negative_df = pd.read_csv(candidate_csv)\n    part_negative_df = negative_df[negative_df[\"class\"] == 0].copy()\n    concat_df = pd.concat([part_positive_df, part_negative_df],axis=0)\n    unique_seriesids = concat_df[\"seriesuid\"].unique().tolist()\n    # Candidates that are finally kept\n    keep_negative_df_list = []\n    remove_nodule_num = 0\n    # For each patient, compare the real nodule annotations with the candidate annotations\n    for seid in unique_seriesids:\n        # First check that the patient's data exists\n        mhd_path = get_mhd_file_path(seid, mhd_root_dir)\n        if 
mhd_path is not None and os.path.exists(mhd_path):\n            one_serid_df = concat_df[concat_df[\"seriesuid\"] == seid].copy()\n            one_seid_postive_df = one_serid_df[one_serid_df[\"class\"] == 1].copy()\n            one_seid_negative_df = one_serid_df[one_serid_df[\"class\"] == 0].copy()\n            # only patients with at least one real nodule contribute negatives here\n            if one_seid_postive_df.shape[0] > 0 and one_seid_negative_df.shape[0] > 0:\n                for _, negative_row in one_seid_negative_df.iterrows():\n                    one_negative_coord_x = negative_row[\"coordX\"]\n                    one_negative_coord_y = negative_row[\"coordY\"]\n                    one_negative_coord_z = negative_row[\"coordZ\"]\n                    keep_current_nodule = True\n                    for _, positive_row in one_seid_postive_df.iterrows():\n                        one_positive_coord_x = positive_row[\"coordX\"]\n                        one_positive_coord_y = positive_row[\"coordY\"]\n                        one_positive_coord_z = positive_row[\"coordZ\"]\n                        x_dist = abs(one_negative_coord_x - one_positive_coord_x)\n                        y_dist = abs(one_negative_coord_y - one_positive_coord_y)\n                        z_dist = abs(one_negative_coord_z - one_positive_coord_z)\n                        # The largest nodule diameter is 32 mm, so no candidate may overlap a real nodule.\n                        # A candidate overlaps only if it is within 16 mm along all three axes\n                        if x_dist < 16 and y_dist < 16 and z_dist < 16:\n                            keep_current_nodule = False\n                            remove_nodule_num = remove_nodule_num + 1\n                            break\n                    # This candidate overlaps none of the real nodules\n                    if keep_current_nodule:\n                        one_keep_nodule_df = pd.DataFrame([{\"seriesuid\": seid,\n                                                            \"coordX\": one_negative_coord_x,\n                                                            \"coordY\": one_negative_coord_y,\n                                                            \"coordZ\": 
one_negative_coord_z,\n                                                            \"class\": 0}])\n                        keep_negative_df_list.append(one_keep_nodule_df)\n    print(\"Number of candidate records kept\\t\", len(keep_negative_df_list))\n    print(\"Number of problematic candidates removed\\t\", remove_nodule_num)\n    if len(keep_negative_df_list) > 0:\n        keep_negative_df = pd.concat(keep_negative_df_list,axis = 0)\n        keep_negative_df.to_csv(save_real_candidate_csv, encoding=\"utf-8\", index = False)\n        print(\"Result saved to\\t\", save_real_candidate_csv)\n\ndef ctdata_annotation2nodule(ct_data, nodule_info, mal_label, cube_size=64):\n    \"\"\"\n    Extract the cube around a single nodule and wrap it in a NoduleCube\n\n    Args:\n        ct_data: CTData instance\n        nodule_info: nodule information (Series)\n        mal_label:   label of the current nodule\n        cube_size: cube size (mm)\n    Returns:\n        NoduleCube instance, or None if the extracted volume is empty\n    \"\"\"\n    # Read the nodule information\n    patient_id = nodule_info['seriesuid']\n    coord_x = nodule_info['coordX']\n    coord_y = nodule_info['coordY']\n    coord_z = nodule_info['coordZ']\n    # Extract the nodule cube from the CT data\n    nodule_cube_data = ct_data.extract_cube([coord_x, coord_y, coord_z], cube_size, if_fixed_radius=True)\n    center_voxel = ct_data.world_to_voxel([coord_x, coord_y, coord_z])\n    # Make sure the nodule volume is not empty\n    if nodule_cube_data.size == 0:\n        print(f\"Warning: nodule volume for patient {patient_id} is empty, check that the coordinates are correct\")\n        # return a single None (not a tuple) so that the caller's `is None` check works\n        return None\n    # Create the NoduleCube instance\n    nodule_cube = NoduleCube.from_array(\n        pixel_data=nodule_cube_data,\n        center_x=int(center_voxel[0]),\n        center_y=int(center_voxel[1]),\n        center_z=int(center_voxel[2]),\n        radius=int(cube_size / 2),\n        malignancy=mal_label\n    )\n    return nodule_cube\n\n\ndef process_nodule(nodule, cube_index,mhd_root_dir, label, if_aug, png_output, npy_output, check_output, check_count_limit):\n    \"\"\"\n    Worker function that processes a single nodule, suitable for multiprocessing\n\n    Args:\n        nodule: nodule information (one DataFrame row)\n        mhd_root_dir: root directory of the MHD files\n        label: label (0=benign, 1=malignant)\n        
if_aug: whether to apply data augmentation\n        png_output: PNG output directory\n        npy_output: NPY output directory\n        check_output: output directory for check images\n        check_count_limit: maximum number of check images\n        cube_index: nodule index\n\n    Returns:\n        dict describing the processing result\n    \"\"\"\n    result = {\n        'success': False,\n        'patient_id': nodule['seriesuid'],\n        'error': None,\n        'check_saved': False\n    }\n\n    patient_id = nodule['seriesuid']\n    # Find the MHD file path\n    mhd_path = get_mhd_file_path(patient_id, mhd_root_dir)\n    if mhd_path is None:\n        result['error'] = \"MHD file not found\"\n        return result\n\n    patient_save_png_path = os.path.join(png_output, f\"{patient_id}_mal={label}_{cube_index}.png\")\n    patient_save_npy_path = os.path.join(npy_output, f\"{patient_id}_mal={label}_{cube_index}.npy\")\n\n    # Skip nodules whose output already exists\n    if os.path.exists(patient_save_png_path):\n        result['success'] = True\n        result['error'] = \"File already exists, skipped\"\n        return result\n\n    try:\n        ct_data = CTData.from_mhd(mhd_path)\n        ct_data.resample_pixel()\n        ct_data.filter_lung_img_mask()\n\n        # Extract the nodule\n        one_nodule_cube = ctdata_annotation2nodule(ct_data, nodule, mal_label=label, cube_size=32)\n\n        if one_nodule_cube is None:\n            result['error'] = \"Nodule volume is empty\"\n            return result\n        # Check that a benign candidate annotation is plausible. resample_pixel and filter_lung_img_mask must run first\n        if label == 0:\n            voxel_coord_x = one_nodule_cube.center_x\n            voxel_coord_y = one_nodule_cube.center_y\n            voxel_coord_z = one_nodule_cube.center_z\n            this_nodule_valid = nodule_valid(ct_data, voxel_coord_x, voxel_coord_y, voxel_coord_z)\n            if not this_nodule_valid:\n                result['error'] = \"Nodule rejected by nodule_valid\"\n                return result\n        # Save as a PNG mosaic and an NPY file\n        one_nodule_cube.save_to_png(patient_save_png_path)\n        one_nodule_cube.save_to_npy(patient_save_npy_path)\n        if if_aug:\n            # 3 rotations\n            rotation_nodule_cube1 = 
one_nodule_cube.augment(rotation=True)\n            patient_save_png_path_rotation1 = patient_save_png_path.replace(\".png\", \"_rotation1.png\")\n            patient_save_npy_path_rotation1 = patient_save_npy_path.replace(\".npy\", \"_rotation1.npy\")\n            rotation_nodule_cube1.save_to_png(patient_save_png_path_rotation1)\n            rotation_nodule_cube1.save_to_npy(patient_save_npy_path_rotation1)\n\n            rotation_nodule_cube2 = one_nodule_cube.augment(rotation=True)\n            patient_save_png_path_rotation2 = patient_save_png_path.replace(\".png\", \"_rotation2.png\")\n            patient_save_npy_path_rotation2 = patient_save_npy_path.replace(\".npy\", \"_rotation2.npy\")\n            rotation_nodule_cube2.save_to_png(patient_save_png_path_rotation2)\n            rotation_nodule_cube2.save_to_npy(patient_save_npy_path_rotation2)\n\n            rotation_nodule_cube3 = one_nodule_cube.augment(rotation=True)\n            patient_save_png_path_rotation3 = patient_save_png_path.replace(\".png\", \"_rotation3.png\")\n            patient_save_npy_path_rotation3 = patient_save_npy_path.replace(\".npy\", \"_rotation3.npy\")\n            rotation_nodule_cube3.save_to_png(patient_save_png_path_rotation3)\n            rotation_nodule_cube3.save_to_npy(patient_save_npy_path_rotation3)\n            # 3 flips, one per axis\n            flip_nodule_cube1 = one_nodule_cube.augment(rotation=False, flip_axis = 0)\n            patient_save_png_path_flip1 = patient_save_png_path.replace(\".png\", \"_flip1.png\")\n            patient_save_npy_path_flip1 = patient_save_npy_path.replace(\".npy\", \"_flip1.npy\")\n            flip_nodule_cube1.save_to_png(patient_save_png_path_flip1)\n            flip_nodule_cube1.save_to_npy(patient_save_npy_path_flip1)\n\n            flip_nodule_cube2 = one_nodule_cube.augment(rotation=False, flip_axis=1)\n            patient_save_png_path_flip2 = patient_save_png_path.replace(\".png\", \"_flip2.png\")\n            
patient_save_npy_path_flip2 = patient_save_npy_path.replace(\".npy\", \"_flip2.npy\")\n            flip_nodule_cube2.save_to_png(patient_save_png_path_flip2)\n            flip_nodule_cube2.save_to_npy(patient_save_npy_path_flip2)\n\n            flip_nodule_cube3 = one_nodule_cube.augment(rotation=False, flip_axis=2)\n            patient_save_png_path_flip3 = patient_save_png_path.replace(\".png\", \"_flip3.png\")\n            patient_save_npy_path_flip3 = patient_save_npy_path.replace(\".npy\", \"_flip3.npy\")\n            flip_nodule_cube3.save_to_png(patient_save_png_path_flip3)\n            flip_nodule_cube3.save_to_npy(patient_save_npy_path_flip3)\n            # 3 noise-only augmentations\n            noise_nodule_cube1 = one_nodule_cube.augment(rotation=False, flip_axis=-1, noise=True)\n            patient_save_png_path_noise1 = patient_save_png_path.replace(\".png\", \"_noise1.png\")\n            patient_save_npy_path_noise1 = patient_save_npy_path.replace(\".npy\", \"_noise1.npy\")\n            noise_nodule_cube1.save_to_png(patient_save_png_path_noise1)\n            noise_nodule_cube1.save_to_npy(patient_save_npy_path_noise1)\n\n            noise_nodule_cube2 = one_nodule_cube.augment(rotation=False, flip_axis=-1, noise=True)\n            patient_save_png_path_noise2 = patient_save_png_path.replace(\".png\", \"_noise2.png\")\n            patient_save_npy_path_noise2 = patient_save_npy_path.replace(\".npy\", \"_noise2.npy\")\n            noise_nodule_cube2.save_to_png(patient_save_png_path_noise2)\n            noise_nodule_cube2.save_to_npy(patient_save_npy_path_noise2)\n\n            noise_nodule_cube3 = one_nodule_cube.augment(rotation=False, flip_axis=-1, noise=True)\n            patient_save_png_path_noise3 = patient_save_png_path.replace(\".png\", \"_noise3.png\")\n            patient_save_npy_path_noise3 = patient_save_npy_path.replace(\".npy\", \"_noise3.npy\")\n            noise_nodule_cube3.save_to_png(patient_save_png_path_noise3)\n            
noise_nodule_cube3.save_to_npy(patient_save_npy_path_noise3)\n\n        # Decide whether to save a check image\n        if cube_index < check_count_limit:\n            viz_path = os.path.join(check_output, f\"{patient_id}_nodule_mal={label}_viz_{cube_index}.png\")\n            one_nodule_cube.visualize_3d(output_path=viz_path, show=False)\n            result['check_saved'] = True\n\n        result['success'] = True\n\n    except Exception as e:\n        result['error'] = str(e)\n    return result\n\ndef prepare_cubes_mp(mhd_root_dir, annotation_csv, label, png_output, npy_output, check_output,if_aug = False, num_processes=None, max_samples = 60000):\n    \"\"\"\n    Multiprocessing version: create nodule cubes from an annotation csv file and the mhd directory\n\n    Args:\n        mhd_root_dir: root directory of the mhd files\n        annotation_csv: nodule annotation file\n        label: whether this annotation file is benign (0) or malignant (1)\n        png_output: output directory for nodule cube images\n        npy_output: output directory for nodule cube data\n        check_output: directory of images used to check that nodules were extracted correctly\n        if_aug:     whether to augment; malignant nodules are scarce and need augmentation\n        num_processes: number of processes, defaults to 80% of the CPU cores\n        max_samples:   maximum number of samples, mainly for negatives: even after augmentation there are only about 9300 positives, while there are tens of thousands of negatives\n    \"\"\"\n    # Create the output directories\n    os.makedirs(png_output, exist_ok=True)\n    os.makedirs(npy_output, exist_ok=True)\n    os.makedirs(check_output, exist_ok=True)\n\n    # Determine the number of processes\n    if num_processes is None:\n        num_processes = max(1, int(mp.cpu_count() * 0.8))\n\n    # Load the annotations\n    annotations_df = pd.read_csv(annotation_csv, encoding=\"utf-8\")\n    # Use only class=0 records: the candidate set still contains many class=1 entries.\n    # annotations.csv (positives) has no class column, so only access it when present\n    if \"class\" in annotations_df.columns:\n        print(annotations_df[\"class\"].unique().tolist())\n        annotations_df = annotations_df[annotations_df[\"class\"] == 0].copy()\n\n    annotations_df = annotations_df[:max_samples]\n    total_nodules = len(annotations_df)\n\n    print(f\"Processing {total_nodules} {'malignant' if label == 1 else 'benign'} nodules with {num_processes} processes\")\n\n    # Limit the number of check images\n    check_count_limit = min(10, total_nodules)\n\n    # Build the argument list\n    args_list = [\n        (\n            row,  # nodule\n            i,  # cube_index\n         
   mhd_root_dir,\n            label,\n            if_aug,\n            png_output,\n            npy_output,\n            check_output,\n            check_count_limit\n        )\n        for i, row in annotations_df.iterrows()\n    ]\n    # Process the nodules with a process pool\n    start_time = time.time()\n    with mp.Pool(processes=num_processes) as pool:\n        # starmap blocks until every task has finished, so tqdm only reaches 100% at the end;\n        # live progress would need pool.imap with a top-level wrapper around process_nodule\n        results = list(tqdm(\n            pool.starmap(process_nodule, args_list),\n            total=total_nodules,\n            desc=f\"Processing {'malignant' if label == 1 else 'benign'} nodules\"\n        ))\n\n    # Count the results\n    success_count = sum(1 for r in results if r['success'])\n    error_count = sum(1 for r in results if not r['success'])\n\n    # Elapsed time\n    elapsed_time = time.time() - start_time\n    avg_time_per_nodule = elapsed_time / total_nodules if total_nodules > 0 else 0\n\n    print(\n        f\"Done! Succeeded: {success_count}, failed: {error_count}, total time: {elapsed_time:.2f}s, average per nodule: {avg_time_per_nodule:.2f}s\")\n\n    # Print the first 5 errors, if any\n    if error_count > 0:\n        print(\"Error samples:\")\n        error_samples = [r for r in results if not r['success']][:5]\n        for i, sample in enumerate(error_samples):\n            print(f\"  {i + 1}. 
Patient ID: {sample['patient_id']}, error: {sample['error']}\")\n\n    return success_count, error_count\n\n\ndef main():\n    \"\"\"Entry point\"\"\"\n    # Paths\n    positive_nodule_annotation_file = \"H:/luna16/annotations.csv\"\n    negative_nodule_annotation_file = \"H:/luna16/candidates_V2.csv\"\n    save_real_candidate_csv = \"H:/luna16/candidates_clean.csv\"\n    luna16_root = \"H:/luna16\"\n    # Select suitable negatives from the candidates; do not use the raw candidate file directly, it contains many problematic entries\n    get_real_candidate(luna16_root, positive_nodule_annotation_file, negative_nodule_annotation_file, save_real_candidate_csv)\n    # Output directories\n    positive_png_save_dir = \"J:/luna16_processed/positive_pngs/\"\n    positive_npy_save_dir = \"J:/luna16_processed/positive_npys/\"\n    check_save_dir = \"J:/luna16_processed/check_pngs/\"\n    # Number of processes, defaults to 80% of the CPU cores\n    num_processes = max(1, int(mp.cpu_count() * 0.8))\n    print(f\"Detected {mp.cpu_count()} CPU cores, using {num_processes} processes\")\n    # Process malignant nodules\n    print(\"\\n===== Processing malignant nodules =====\")\n    # success_pos, error_pos = prepare_cubes_mp(\n    #     luna16_root,\n    #     positive_nodule_annotation_file,\n    #     1,\n    #     positive_png_save_dir,\n    #     positive_npy_save_dir,\n    #     check_save_dir,\n    #     True,\n    #     num_processes\n    # )\n    # Process benign nodules\n    print(\"\\n===== Processing benign nodules =====\")\n    negative_png_save_dir = \"J:/luna16_processed/negative_pngs/\"\n    negative_npy_save_dir = \"J:/luna16_processed/negative_npys/\"\n    success_neg, error_neg = prepare_cubes_mp(\n        luna16_root,\n        save_real_candidate_csv,\n        0,\n        negative_png_save_dir,\n        negative_npy_save_dir,\n        check_save_dir,\n        False,\n        num_processes\n    )\n    # Summary\n    print(\"\\n===== Summary =====\")\n    # print(f\"Malignant nodules: succeeded {success_pos}, failed {error_pos}\")\n    print(f\"Benign nodules: succeeded {success_neg}, failed {error_neg}\")\n    # print(f\"Total: succeeded {success_pos + success_neg}, failed {error_pos + error_neg}\")\n\nif __name__ == \"__main__\":\n    # Avoid multiprocessing issues on Windows\n    
mp.freeze_support()\n    main()"
  },
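  {
    "path": "deploy/README.md",
    "content": "# CT Image Analysis System\n\nA deep-learning-based lung CT analysis system that detects pulmonary nodules and estimates their probability of malignancy.\n\n## Features\n\n- Supports CT data in DICOM and MHD/RAW formats, uploaded as a ZIP archive\n- Full 3D lung visualization (volume view, axial, coronal, and sagittal planes)\n- Automatic detection of lung nodules, showing their position and size\n- Malignancy probability for each nodule, with medical recommendations\n- Interactive browsing of nodule previews\n\n## Requirements\n\n- Python 3.11 or later\n- Browser: Chrome, Firefox, Edge (latest version)\n- Operating system: Windows 10/11, macOS, Linux\n\n## Dependencies\n\n- Flask: web server\n- NumPy: numerical computing\n- PyTorch: deep learning framework (the model is loaded from a `.pth` checkpoint)\n- SimpleITK: medical image processing\n- OpenCV, Pillow, Matplotlib: image handling in the backend\n- Three.js: 3D rendering (already referenced by the frontend)\n\n## Installation\n\n1. Install the dependencies:\n\n```bash\npip install flask flask-cors numpy torch SimpleITK opencv-python pillow matplotlib\n```\n\n2. Start the system:\n\n```bash\npython run.py\n```\n\n## Usage\n\n1. After startup, a browser opens `http://localhost:5000` automatically\n2. Click the \"上传CT数据\" (Upload CT data) button and select a ZIP archive containing the CT data\n3. The system then runs the following steps automatically:\n   - Upload\n   - Data preprocessing\n   - Model analysis\n   - Result visualization\n4. When processing finishes, the main view shows the 3D lung model and the detected nodules\n5. The nodule preview panel on the right lists every detected nodule; click one to highlight it\n6. Use the view buttons to see the lung model from different angles\n\n## File Structure\n\n```\ndeploy/\n├── backend/              # Backend code\n│   ├── data/             # Model weights (c3d_nodule_detect.pth)\n│   ├── uploads/          # Temporary storage for uploaded files\n│   └── app.py            # Backend entry point\n├── frontend/             # Frontend code\n│   ├── css/              # Stylesheets\n│   ├── js/               # JavaScript code\n│   └── index.html        # Main page\n├── run.py                # Startup script\n└── README.md             # This document\n```\n\n## Notes\n\n- This system is for research and teaching only and must not be used for actual medical diagnosis\n- Processing large CT volumes can take a long time\n- Nodule detection results are for reference only; diagnosis must be made by a qualified physician "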
  },
  {
    "path": "deploy/backend/app.py",
    "content": "import os,cv2\nimport sys\nimport numpy as np\nfrom flask import Flask, request, jsonify, send_from_directory, send_file, Response, make_response\nfrom flask_cors import CORS\nimport logging\nimport json\nimport uuid\nimport threading\nimport zipfile\nfrom PIL import Image\nfrom datetime import datetime\nimport io\nimport struct\n# Add the project root to sys.path before importing project modules\nsys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))\nfrom deploy.backend.dataclass.CTData import CTData\nfrom util.dicom_util import is_dicom_file\nfrom io import BytesIO\nimport matplotlib\n# Use a non-interactive matplotlib backend\nmatplotlib.use('Agg')\nmatplotlib.rcParams['font.sans-serif'] = ['Microsoft YaHei']  # or another CJK font installed on your system\nmatplotlib.rcParams['axes.unicode_minus'] = False  # render minus signs correctly\nimport matplotlib.pyplot as plt\n\n# Import the detector module\nfrom detector import get_detector_instance, STATUS_COMPLETED, STATUS_ERROR\n\n# Configure logging\nlogging.basicConfig(level=logging.INFO)\nlogger = logging.getLogger(__name__)\n# Model parameters\nCUBE_SIZE = 32\n# Upload folder and model folder\nUPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')\nMODEL_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')\nos.makedirs(UPLOAD_FOLDER, exist_ok=True)\nos.makedirs(MODEL_FOLDER, exist_ok=True)\n# Allowed file extensions\nALLOWED_EXTENSIONS = {'mhd', 'raw', 'nii', 'nii.gz', 'dcm', 'dicom', 'zip'}\n# Session store: uploaded files and processing state for each session\nSESSION_DATA = {}\n# Create the Flask app\napp = Flask(__name__, static_folder='../frontend')\nCORS(app)  # enable cross-origin requests\n# App configuration\napp.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER\napp.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024  # 500 MB limit\n\n# Default model path\nDEFAULT_MODEL_PATH = os.path.join(MODEL_FOLDER, 'c3d_nodule_detect.pth')\n# Create the detector instance\ndetector = get_detector_instance(DEFAULT_MODEL_PATH)\n\n\n# Detection-completed callback\ndef on_detection_completed(session_id, results):\n    \"\"\"Called when detection finishes; updates the session state\"\"\"\n    logger.info(f\"Detection-completed callback: session {session_id}\")\n    if session_id in SESSION_DATA:\n        # Mark the session as completed\n        
SESSION_DATA[session_id]['status'] = 'completed'\n        SESSION_DATA[session_id]['results'] = results\n        SESSION_DATA[session_id]['progress'] = 100\n        SESSION_DATA[session_id]['message'] = f\"Detection finished, found {len(results.get('nodules', []))} nodules\"\n\n        # Sync the state to the detector\n        detector.update_session_state(session_id, SESSION_DATA[session_id])\n\n        # Generate nodule images\n        try:\n            nodules = results.get('nodules', [])\n            if nodules:\n                logger.info(f\"Detection finished, generating images for {len(nodules)} nodules\")\n                lung_seg_path = os.path.join(UPLOAD_FOLDER, session_id, 'lung_seg.npy')\n                if os.path.exists(lung_seg_path):\n                    lung_img = np.load(lung_seg_path)\n                    if lung_img is not None and lung_img.size > 0:\n                        logger.info(f\"Loaded lung segmentation, shape: {lung_img.shape}\")\n                        success = generate_nodule_images(session_id, nodules, lung_img)\n                        if success:\n                            logger.info(\"Nodule images generated\")\n                        else:\n                            logger.error(\"Failed to generate nodule images\")\n                    else:\n                        logger.error(\"Lung segmentation data is empty or invalid\")\n                else:\n                    logger.error(f\"Lung segmentation file not found: {lung_seg_path}\")\n            else:\n                logger.info(\"No nodules, skipping image generation\")\n        except Exception as e:\n            logger.error(f\"Error while generating nodule images: {str(e)}\", exc_info=True)\n\n\n# Register the completion callback if the detector supports it\nif hasattr(detector, 'set_completion_callback'):\n    detector.set_completion_callback(on_detection_completed)\nelse:\n    logger.warning(\"Detector does not support a completion callback; session status may not be updated correctly\")\n\n# Compatibility shim: make sure detector has update_session_state\nif not hasattr(detector, 'update_session_state'):\n    def update_session_state(session_id, state):\n        \"\"\"\n        Compatibility method for updating a session's state.\n        Ensures the session_states dict exists, then stores the state\n        \"\"\"\n        if not hasattr(detector, 'session_states'):\n            detector.session_states = 
{}\n\n        if not hasattr(detector, 'session_locks'):\n            detector.session_locks = {}\n\n        # Create the session lock if it does not exist yet\n        if session_id not in detector.session_locks:\n            detector.session_locks[session_id] = threading.Lock()\n\n        # Update the session state\n        with detector.session_locks[session_id]:\n            detector.session_states[session_id] = state.copy()  # store a copy to avoid shared references\n\n        logger.info(f\"Updated state of session {session_id}: {state.get('status')}\")\n\n\n    # Attach the method to the detector object\n    detector.update_session_state = update_session_state\n\n    # Also add get_session_state if it is missing\n    if not hasattr(detector, 'get_session_state'):\n        def get_session_state(session_id):\n            \"\"\"Compatibility method for reading a session's state\"\"\"\n            if not hasattr(detector, 'session_states'):\n                detector.session_states = {}\n\n            return detector.session_states.get(session_id)\n\n\n        detector.get_session_state = get_session_state\n\n\ndef allowed_file(filename):\n    \"\"\"Check whether the file extension is allowed\"\"\"\n    return '.' 
in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS\n\n\n@app.route('/')\ndef index():\n    \"\"\"Serve the frontend page\"\"\"\n    return send_from_directory('../frontend', 'index.html')\n\n\n@app.route('/<path:path>')\ndef static_files(path):\n    \"\"\"Serve static files\"\"\"\n    return send_from_directory('../frontend', path)\n\n\n@app.route('/api/upload', methods=['POST'])\ndef upload_file():\n    \"\"\"\n    Handle CT file uploads.\n    Accepts a ZIP archive containing DICOM or MHD/RAW files,\n    extracts it and detects the file type\n    \"\"\"\n    try:\n        # Check that a file was sent\n        if 'file' not in request.files:\n            return jsonify({'success': False, 'error': 'No file found'}), 400\n\n        file = request.files['file']\n\n        # The user did not select a file\n        if file.filename == '':\n            return jsonify({'success': False, 'error': 'No file selected'}), 400\n\n        # Only ZIP archives are accepted\n        if not file.filename.lower().endswith('.zip'):\n            return jsonify({'success': False, 'error': 'Please upload a ZIP file'}), 400\n\n        # Create a new session ID\n        session_id = str(uuid.uuid4())\n\n        # Create the session directory\n        session_dir = os.path.join(UPLOAD_FOLDER, session_id)\n        os.makedirs(session_dir, exist_ok=True)\n\n        # Save the archive\n        zip_path = os.path.join(session_dir, 'upload.zip')\n        file.save(zip_path)\n        app.logger.info(f\"Saved archive to {zip_path}\")\n\n        # Extract the archive\n        extract_dir = os.path.join(session_dir, 'extracted')\n        os.makedirs(extract_dir, exist_ok=True)\n\n        with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n            zip_ref.extractall(extract_dir)\n        zip_filename = os.path.splitext(os.path.basename(file.filename))[0]\n        real_extracted_path = os.path.join(extract_dir, zip_filename)\n        # The archive is expected to hold a top-level folder named after the ZIP file;\n        # fall back to the extraction root when it does not\n        if not os.path.isdir(real_extracted_path):\n            real_extracted_path = extract_dir\n        print(\"Extracted data directory:\\t\", real_extracted_path)\n        app.logger.info(f\"Extracted files to {real_extracted_path}\")\n        # Detect the file type\n        file_type, file_paths = detect_file_type(real_extracted_path)\n        print(\"Detected type:\\t\", file_type)\n        print(\"Detected files:\\t\", file_paths)\n        if not 
file_type:\n            return jsonify({'success': False, 'error': 'No supported CT data (DICOM or MHD/RAW) found in the archive'}), 400\n\n        # Record the session information\n        session_info = {\n            'id': session_id,\n            'timestamp': datetime.now().isoformat(),\n            'file_type': file_type,\n            'files': file_paths,\n            'extract_dir': real_extracted_path,\n            'status': 'uploaded'\n        }\n\n        # Optional: store the patient ID\n        if 'patient_id' in request.form:\n            session_info['patient_id'] = request.form['patient_id']\n\n        # Save the session information to a file\n        session_info_path = os.path.join(session_dir, 'session_info.json')\n        with open(session_info_path, 'w') as f:\n            json.dump(session_info, f)\n\n        # Register the session in the global session dict\n        SESSION_DATA[session_id] = {\n            'status': 'uploaded',\n            'progress': 0,\n            'message': f'Uploaded and extracted {file_type} CT data, waiting for detection',\n            'file_type': file_type,\n            'files': file_paths\n        }\n\n        # If automatic processing is enabled, start preprocessing\n        auto_detect = app.config.get('AUTO_DETECT', False)\n        if auto_detect:\n            # Run preprocessing in a background thread\n            threading.Thread(target=start_preprocessing, args=(session_id,)).start()\n            app.logger.info(f\"Started automatic preprocessing, session ID: {session_id}\")\n\n        # Success response\n        return jsonify({\n            'success': True,\n            'message': f'File uploaded successfully, detected {file_type} CT data',\n            'session_id': session_id,\n            'file_type': file_type,\n            'auto_detect': auto_detect\n        })\n\n    except Exception as e:\n        app.logger.error(f\"File upload failed: {str(e)}\", exc_info=True)\n        return jsonify({\n            'success': False,\n            'error': f\"File upload failed: {str(e)}\"\n        }), 400\n\n\ndef detect_file_type(directory):\n    \"\"\"\n    Detect the type of CT data in a directory.\n    Supports DICOM and MHD/RAW.\n    Returns: (file type, list of file paths)\n    \"\"\"\n    # Collect all files recursively\n    file_list = []\n    for root, _, files in os.walk(directory):\n        for file in files:\n   
         file_list.append(os.path.join(root, file))\n    print(\"Scanning directory:\\t\", directory)\n    print(file_list)\n    # Look for MHD files\n    mhd_files = [f for f in file_list if f.lower().endswith('.mhd')]\n    if mhd_files:\n        # Check that the matching RAW file exists\n        for mhd_file in mhd_files:\n            # Name of the matching RAW file (same stem, .raw extension)\n            raw_file = os.path.splitext(mhd_file)[0] + '.raw'\n            # Case-insensitive comparison\n            if any(f.lower() == raw_file.lower() for f in file_list):\n                # Found an MHD file and its RAW file\n                return 'MHD/RAW', [mhd_file]\n\n    # Look for DICOM files\n    dicom_files = [f for f in file_list if f.lower().endswith(('.dcm', '.dicom')) or\n                   (os.path.isfile(f) and is_dicom_file(f))]\n\n    if dicom_files:\n        return 'DICOM', dicom_files\n\n    # No supported file type found\n    return None, []\n\n\ndef start_preprocessing(session_id):\n    \"\"\"\n    Start CT preprocessing:\n    convert the CT data into lung segmentation data\n    \"\"\"\n    try:\n        # Get the session state\n        if session_id not in SESSION_DATA:\n            logger.error(f\"Session {session_id} does not exist\")\n            return\n\n        state = SESSION_DATA[session_id]\n        state['status'] = 'preprocessing'\n        state['progress'] = 0\n        state['message'] = 'Starting preprocessing...'\n\n        # Sync the session state to the detector\n        detector.update_session_state(session_id, state)\n\n        # Session directory\n        session_dir = os.path.join(UPLOAD_FOLDER, session_id)\n        # Load the session information\n        with open(os.path.join(session_dir, 'session_info.json'), 'r') as f:\n            session_info = json.load(f)\n        print(\"Loaded session info:\\n\", session_info)\n\n        file_type = session_info['file_type']\n        extract_dir = session_info['extract_dir']  # the saved extraction directory\n\n        # Update the state\n        state['progress'] = 10\n        state['message'] = f'Loading {file_type} CT data...'\n        # Sync the state update\n        detector.update_session_state(session_id, state)\n\n        # Load according to the file type\n        ct_data = None\n\n        print(\"Extraction directory in use:\\n\", extract_dir)\n\n        if file_type == 
'MHD/RAW':\n            # For MHD/RAW, pass the path of the .mhd header file itself\n            ct_data = CTData.from_mhd(session_info['files'][0])\n\n        elif file_type == 'DICOM':\n            # For DICOM, pass the directory containing all DICOM files\n            ct_data = CTData.from_dicom(extract_dir)\n\n        # Check whether loading succeeded\n        if ct_data is None:\n            state['status'] = 'error'\n            state['message'] = 'Failed to load CT data'\n            SESSION_DATA[session_id] = state\n            detector.update_session_state(session_id, state)\n            logger.error(f\"Failed to load CT data, session ID: {session_id}\")\n            return\n\n        # Update progress\n        state['progress'] = 30\n        state['message'] = 'Running lung segmentation...'\n        # Sync the state update\n        detector.update_session_state(session_id, state)\n\n        # Resample to 1 mm isotropic voxels; resample_pixel() returns a new CTData object, so reassign it\n        ct_data = ct_data.resample_pixel()\n        # Segment the lungs slice by slice; the result is normalized to the 0-1 range\n        ct_data.filter_lung_img_mask()\n        lung_seg = ct_data.lung_seg_img\n\n        # Save the lung segmentation result\n        lung_seg_path = os.path.join(session_dir, 'lung_seg.npy')\n        np.save(lung_seg_path, lung_seg)\n\n        # Update state: lung segmentation finished\n        state['progress'] = 40\n        state['message'] = 'Lung segmentation complete, preparing detection'\n        state['status'] = 'preprocessed'  # mark as preprocessed as soon as lung segmentation is done\n        # Record the lung segmentation path\n        state['lung_seg_path'] = lung_seg_path\n        SESSION_DATA[session_id] = state\n        detector.update_session_state(session_id, state)\n\n        # Start detection automatically, passing the correct path\n        if file_type == 'MHD/RAW':\n            # MHD/RAW: pass the file path\n            detector.detect(session_info['files'][0], session_id, patient_id=None)\n        else:\n            # DICOM: pass the directory path\n            detector.detect(extract_dir, session_id, patient_id=None)\n\n        # Save every axial slice of the lung segmentation as an image\n        save_lung_segmentation_slices(session_id, lung_seg)\n\n    except Exception as e:\n        logger.error(f\"Preprocessing failed: {str(e)}\", exc_info=True)\n\n        # Mark the session as failed\n        if session_id in SESSION_DATA:\n            SESSION_DATA[session_id]['status'] = 'error'\n            SESSION_DATA[session_id]['message'] = f'Preprocessing failed: 
{str(e)}'\n            detector.update_session_state(session_id, SESSION_DATA[session_id])\n\n\n@app.route('/api/detect', methods=['POST'])\ndef start_detection():\n    \"\"\"Start the detection process\"\"\"\n    try:\n        # Read the request parameters (an empty dict if the body is missing or is not valid JSON)\n        data = request.get_json(silent=True) or {}\n        session_id = data.get('session_id')\n        patient_id = data.get('patient_id', session_id)\n\n        if not session_id:\n            return jsonify({'success': False, 'error': 'Missing session ID'}), 400\n\n        # Locate the session folder\n        session_folder = os.path.join(UPLOAD_FOLDER, session_id)\n        if not os.path.exists(session_folder):\n            return jsonify({'success': False, 'error': f'Session {session_id} does not exist'}), 404\n\n        # Read session_info.json to get the correct file path\n        session_info_path = os.path.join(session_folder, 'session_info.json')\n        if not os.path.exists(session_info_path):\n            return jsonify({'success': False, 'error': 'Session info file does not exist'}), 404\n\n        # Load the session information\n        with open(session_info_path, 'r') as f:\n            session_info = json.load(f)\n\n        # Choose the correct path according to the file type\n        file_type = session_info.get('file_type')\n        extract_dir = session_info.get('extract_dir')\n\n        if not file_type or not extract_dir:\n            return jsonify({'success': False, 'error': 'Incomplete session information'}), 400\n\n        # Same logic as in start_preprocessing\n        if file_type == 'MHD/RAW':\n            # MHD/RAW: use the file path\n            file_path = session_info['files'][0]\n        else:\n            # DICOM: use the directory path\n            file_path = extract_dir\n\n        logger.info(f\"Starting detection with file path {file_path}\")\n\n        # Start detection\n        if not detector.detect(file_path, session_id, patient_id):\n            return jsonify({'success': False, 'error': 'Failed to start detection; make sure the model is loaded correctly'}), 500\n\n        # Return the session ID; the client polls progress with it\n        return jsonify({\n            'success': True,\n            'session_id': session_id,\n            'message': 'Detection started'\n        })\n\n    except Exception as e:\n        logger.error(f\"Error starting detection: {str(e)}\", 
exc_info=True)\n        return jsonify({'success': False, 'error': str(e)}), 500\n\n\n@app.route('/api/progress/<session_id>', methods=['GET'])\ndef get_progress(session_id):\n    \"\"\"Get detection progress\"\"\"\n    try:\n        # Get the session state, trying the detector first\n        state = detector.get_session_state(session_id)\n\n        # Fall back to SESSION_DATA if the detector has no state for this session\n        if not state and session_id in SESSION_DATA:\n            state = SESSION_DATA[session_id]\n            logger.info(f\"Got the state of session {session_id} from SESSION_DATA\")\n\n        # Neither source knows the session\n        if not state:\n            return jsonify({'success': False, 'error': f'Session {session_id} does not exist'}), 404\n\n        # Build the response (states created by update_session_state may lack 'message')\n        response = {\n            'success': True,\n            'session_id': session_id,\n            'status': state['status'],\n            'progress': state.get('progress', 0),\n            'message': state.get('message', ''),\n            'started_at': state.get('started_at'),\n            'completed_at': state.get('completed_at')\n        }\n\n        # When detection is complete, include the nodule count\n        if state['status'] == STATUS_COMPLETED and 'nodules' in state:\n            response['nodules_count'] = len(state['nodules'])\n\n        # When an error occurred, include the error message\n        if state['status'] == STATUS_ERROR and 'error' in state:\n            response['error'] = state['error']\n\n        return jsonify(response)\n\n    except Exception as e:\n        logger.error(f\"Error getting progress: {str(e)}\", exc_info=True)\n        return jsonify({'success': False, 'error': str(e)}), 500\n\n\n@app.route('/api/results/<session_id>', methods=['GET'])\ndef get_results(session_id):\n    \"\"\"Get detection results\"\"\"\n    try:\n        # Get the session state\n        state = detector.get_session_state(session_id)\n        if not state:\n            return jsonify({'success': False, 'error': f'Session {session_id} does not exist'}), 404\n        print(\"api/results: session status\", state['status'])\n        # Check whether detection has finished\n        if state['status'] != STATUS_COMPLETED:\n            return jsonify({\n                'success': False,\n
                'error': 'Detection has not finished yet',\n                'status': state['status'],\n                'progress': state.get('progress', 0)\n            }), 400\n\n        # Check whether any nodules were found\n        if 'nodules' not in state or not state['nodules']:\n            print('No nodules in the final result')\n            return jsonify({\n                'success': True,\n                'session_id': session_id,\n                'nodules_count': 0,\n                'nodules': [],\n                'message': 'No nodules detected'\n            })\n\n        # Return the nodule information\n        # Note: the full cube data is omitted to reduce the response size\n        # Load the lung segmentation once, outside the loop, instead of once per nodule\n        lung_seg_path = os.path.join(UPLOAD_FOLDER, session_id, 'lung_seg.npy')\n        lung_seg_img = np.load(lung_seg_path)\n        nodules_info = []\n        for nodule in state['nodules']:\n            # Copy the nodule information without the cube data\n            nodule_info = {\n                'id': nodule['id'],\n                'voxel_coords': nodule['voxel_coords'],\n                'world_coords': nodule['world_coords'],\n                'diameter_mm': nodule['diameter_mm'],\n                'probability': nodule['probability']\n            }\n            nodules_info.append(nodule_info)\n            # Draw this nodule's bounding box onto the saved axial slice images\n            save_slice_box(session_id, lung_seg_img, nodule['voxel_coords'], radius=32)\n\n        return jsonify({\n            'success': True,\n            'session_id': session_id,\n            'nodules_count': len(nodules_info),\n            'nodules': nodules_info,\n            'message': f'Detection complete: found {len(nodules_info)} nodule(s)'\n        })\n\n    except Exception as e:\n        logger.error(f\"Error getting results: {str(e)}\", exc_info=True)\n        return jsonify({'success': False, 'error': str(e)}), 500\n\n\n@app.route('/api/nodule/<session_id>/<int:nodule_id>', methods=['GET'])\ndef get_nodule_data(session_id, nodule_id):\n    try:\n        # Check that the session exists\n        if session_id not in SESSION_DATA:\n            return jsonify({\"success\": False, \"error\": \"Session does not exist\"}), 404\n\n        # Get the session data\n        session = SESSION_DATA[session_id]\n        print('api/nodule: session status\\t', session['status'])\n        # Check whether detection has finished\n        if session['status'] not in ['preprocessed', 'completed']:\n            return jsonify({\"success\": False, \"error\": \"CT detection has not finished\"}), 400\n        # Get the detection results\n        results = session.get('results', {})\n        nodules = results.get('nodules', [])\n        # Find the requested nodule\n        target_nodule = None\n        for nodule in nodules:\n            if nodule.get('id') == nodule_id:\n                target_nodule = nodule\n                break\n        if not target_nodule:\n            return jsonify({\"success\": False, \"error\": \"Nodule not found\"}), 404\n        # Return the nodule data\n        return jsonify({\n            \"success\": True,\n            \"nodule\": target_nodule\n        })\n\n    except Exception as e:\n        app.logger.error(f\"Error getting nodule data: {str(e)}\", exc_info=True)\n        return jsonify({\"success\": False, \"error\": f\"Error getting nodule data: {str(e)}\"}), 500\n\n\n# Generate and save slice images for every nodule once detection has finished\ndef generate_nodule_images(session_id, nodules, lung_img):\n    \"\"\"\n    Generate and save slice images for all nodules\n\n    Args:\n        session_id: session ID\n        nodules: list of nodules\n        lung_img: lung segmentation data (z, y, x)\n    \"\"\"\n    try:\n        # Make sure the nodule image directory exists\n        nodule_images_dir = os.path.join(UPLOAD_FOLDER, session_id, 'nodule_images')\n        os.makedirs(nodule_images_dir, exist_ok=True)\n        app.logger.info(f\"Generating slice images for {len(nodules)} nodules\")\n\n        # Log statistics of the lung image for debugging\n        app.logger.info(f\"Lung image stats: shape={lung_img.shape}, dtype={lung_img.dtype}, \"\n                        f\"min={lung_img.min()}, max={lung_img.max()}, \"\n                        f\"mean={lung_img.mean()}, median={np.median(lung_img)}\")\n\n        # Check whether the lung image is all (or almost all) zeros\n        if np.max(lung_img) < 0.01:\n            app.logger.warning(\"Warning: the lung image is almost entirely zero, so the nodule images may render black\")\n            # Stretch the contrast: scale by 100 and clip back to [0, 1]\n            lung_img = (lung_img * 100).clip(0, 1)\n
            app.logger.info(\n                f\"Stats after contrast stretching: min={lung_img.min()}, max={lung_img.max()}, mean={lung_img.mean()}\")\n\n        # Process each nodule\n        for nodule in nodules:\n            nodule_id = nodule.get('id')\n            # Get the nodule center coordinates\n            voxel_coords = nodule.get('voxel_coords', [0, 0, 0])\n            app.logger.info(f\"Processing nodule {nodule_id}, original coordinates (xyz): {voxel_coords}\")\n            # Convert the coordinates to integers (they arrive in x, y, z order)\n            x, y, z = [int(coord) for coord in voxel_coords]\n            # Reorder to z, y, x to match the axis order of lung_img\n            center_voxel_zyx = [z, y, x]\n            app.logger.info(f\"Converted coordinates (zyx): {center_voxel_zyx}\")\n            # Check that the center lies inside the lung image\n            if (center_voxel_zyx[0] < 0 or center_voxel_zyx[0] >= lung_img.shape[0] or\n                    center_voxel_zyx[1] < 0 or center_voxel_zyx[1] >= lung_img.shape[1] or\n                    center_voxel_zyx[2] < 0 or center_voxel_zyx[2] >= lung_img.shape[2]):\n                app.logger.error(\n                    f\"Nodule {nodule_id} lies outside the lung image: {center_voxel_zyx}, image shape: {lung_img.shape}\")\n                continue\n            # Half the edge length of the 32-voxel cube (CUBE_SIZE / 2)\n            half_size = 16\n            # Cube bounds, clipped to the image\n            z_min = max(0, center_voxel_zyx[0] - half_size)\n            y_min = max(0, center_voxel_zyx[1] - half_size)\n            x_min = max(0, center_voxel_zyx[2] - half_size)\n\n            z_max = min(lung_img.shape[0], center_voxel_zyx[0] + half_size)\n            y_max = min(lung_img.shape[1], center_voxel_zyx[1] + half_size)\n            x_max = min(lung_img.shape[2], center_voxel_zyx[2] + half_size)\n\n            # Check that the extraction range is valid\n            if z_min >= z_max or y_min >= y_max or x_min >= x_max:\n                app.logger.error(\n                    f\"Invalid extraction range for nodule {nodule_id}: z({z_min}-{z_max}), y({y_min}-{y_max}), x({x_min}-{x_max})\")\n                continue\n\n            # Extract the sub-volume\n            cube = lung_img[z_min:z_max, y_min:y_max, x_min:x_max]\n            # 
Check whether the cube is empty\n            if cube.size == 0:\n                app.logger.error(f\"Extracted cube for nodule {nodule_id} is empty\")\n                continue\n\n            app.logger.info(f\"Extracted cube for nodule {nodule_id}, shape: {cube.shape}\")\n            app.logger.info(f\"Cube stats: min={cube.min()}, max={cube.max()}, mean={cube.mean()}\")\n\n            # Generate one image per anatomical plane\n            for plane_type in ['axial', 'coronal', 'sagittal']:\n                try:\n                    # Take the central slice for this plane\n                    if plane_type == 'axial':  # slice along the Z axis\n                        if cube.shape[0] == 0:\n                            app.logger.error(f\"Nodule {nodule_id} has size 0 along the axial axis\")\n                            continue\n                        center_index = min(cube.shape[0] // 2, cube.shape[0] - 1)\n                        slice_data = cube[center_index, :, :]\n                    elif plane_type == 'coronal':  # slice along the Y axis\n                        if cube.shape[1] == 0:\n                            app.logger.error(f\"Nodule {nodule_id} has size 0 along the coronal axis\")\n                            continue\n                        center_index = min(cube.shape[1] // 2, cube.shape[1] - 1)\n                        slice_data = cube[:, center_index, :]\n                    elif plane_type == 'sagittal':  # slice along the X axis\n                        if cube.shape[2] == 0:\n                            app.logger.error(f\"Nodule {nodule_id} has size 0 along the sagittal axis\")\n                            continue\n                        center_index = min(cube.shape[2] // 2, cube.shape[2] - 1)\n                        slice_data = cube[:, :, center_index]\n\n                    # Check whether the slice is empty\n                    if slice_data.size == 0:\n                        app.logger.error(f\"Nodule {nodule_id} has an empty {plane_type} slice\")\n                        continue\n\n                    app.logger.info(\n                        f\"Slice stats: min={slice_data.min()}, max={slice_data.max()}, mean={slice_data.mean()}\")\n\n                    # Contrast enhancement\n                    # 
If every value is tiny (close to 0), boost the contrast\n                    if slice_data.max() < 0.1:\n                        # Scale the data by 100 to make it visible\n                        slice_data = np.clip(slice_data * 100, 0, 1)\n                        app.logger.info(f\"After contrast boost: min={slice_data.min()}, max={slice_data.max()}\")\n                    # Normalize to uint8 (0-255)\n                    if np.all(slice_data == 0):  # all zeros\n                        app.logger.warning(f\"Slice data of nodule {nodule_id} is all zeros\")\n                        normalized_slice = np.zeros(slice_data.shape, dtype=np.uint8)\n                    elif slice_data.min() >= 0 and slice_data.max() <= 1:\n                        # Data in the 0-1 range: scale to 0-255\n                        normalized_slice = (slice_data * 255).astype(np.uint8)\n                    elif slice_data.min() >= 0 and slice_data.max() <= 255:\n                        normalized_slice = slice_data.astype(np.uint8)\n                    else:\n                        # Otherwise apply min-max normalization\n                        min_val = slice_data.min()\n                        max_val = slice_data.max()\n                        if max_val > min_val:\n                            normalized_slice = ((slice_data - min_val) / (max_val - min_val) * 255).astype(np.uint8)\n                        else:\n                            normalized_slice = np.zeros_like(slice_data, dtype=np.uint8)\n\n                    app.logger.info(\n                        f\"Stats after normalization: min={normalized_slice.min()}, max={normalized_slice.max()}, mean={normalized_slice.mean()}\")\n\n                    # Extra enhancement when the image is still very dark\n                    if normalized_slice.max() < 50:\n                        # Apply CLAHE (contrast-limited adaptive histogram equalization)\n                        app.logger.info(\"Applying CLAHE to enhance contrast\")\n                        try:\n                            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))\n                            normalized_slice = clahe.apply(normalized_slice)\n                        except cv2.error:\n                            # If CLAHE fails, fall back to a simple linear stretch (max < 50, so multiplying by 5 stays below 255)\n
                            normalized_slice = np.clip(normalized_slice * 5, 0, 255).astype(np.uint8)\n\n                    # Create the figure\n                    plt.figure(figsize=(6, 6))\n                    plt.imshow(normalized_slice, cmap='gray', vmin=0, vmax=255)\n\n                    # Add a title\n                    view_names = {\n                        'axial': 'Axial view (XY)',\n                        'coronal': 'Coronal view (XZ)',\n                        'sagittal': 'Sagittal view (YZ)'\n                    }\n                    plt.title(f\"Nodule {nodule_id} - {view_names.get(plane_type, plane_type)}\")\n\n                    # Mark the center\n                    center_y, center_x = normalized_slice.shape[0] // 2, normalized_slice.shape[1] // 2\n                    plt.plot(center_x, center_y, 'r+', markersize=10)\n\n                    # Outline the nodule diameter (treats mm as pixels, i.e. assumes 1 mm voxel spacing)\n                    diameter = nodule.get('diameter_mm', 10)\n                    radius_pixels = diameter / 2\n                    circle = plt.Circle((center_x, center_y), radius_pixels,\n                                        color='r', fill=False, linestyle='--')\n                    plt.gca().add_patch(circle)\n\n                    # Add a color bar showing pixel values\n                    plt.colorbar(label='Pixel value')\n\n                    # Hide the axes\n                    plt.axis('off')\n                    plt.tight_layout()\n\n                    # Output file name and path\n                    img_filename = f\"nodule_{nodule_id}_{plane_type}.png\"\n                    img_path = os.path.join(nodule_images_dir, img_filename)\n\n                    # Save the figure\n                    plt.savefig(img_path, format='png', dpi=100)\n                    plt.close('all')  # make sure every figure is closed\n\n                    app.logger.info(f\"Saved nodule image: {img_path}\")\n                except Exception as slice_error:\n                    app.logger.error(f\"Error generating the {plane_type} image for nodule {nodule_id}: {str(slice_error)}\",\n                                     exc_info=True)\n\n        app.logger.info(\"Nodule image generation finished\")\n        return True\n\n    except Exception as e:\n        app.logger.error(f\"Error generating nodule images: {str(e)}\", exc_info=True)\n        return False\n\n\n# Session-state update hook: once detection has completed, generate the nodule images\ndef update_session_state(session_id, state):\n    \"\"\"Update the status of a session (state is a status string such as 'completed')\"\"\"\n    if session_id in SESSION_DATA:\n        SESSION_DATA[session_id]['status'] = state\n\n        # When detection has completed, generate images for all nodules\n        if state == 'completed' and 'results' in SESSION_DATA[session_id]:\n            nodules = SESSION_DATA[session_id]['results'].get('nodules', [])\n            if nodules:\n                app.logger.info(\"Detection complete, generating nodule images\")\n                # Load the lung segmentation used to render the images\n                try:\n                    lung_seg_path = os.path.join(UPLOAD_FOLDER, session_id, 'lung_seg.npy')\n                    if os.path.exists(lung_seg_path):\n                        lung_img = np.load(lung_seg_path)\n                        # Generate the nodule images\n                        generate_nodule_images(session_id, nodules, lung_img)\n                    else:\n                        app.logger.error(\"Cannot generate nodule images: lung segmentation data does not exist\")\n                except Exception as e:\n                    app.logger.error(f\"Error generating nodule images: {str(e)}\", exc_info=True)\n    else:\n        # Create a new session\n        SESSION_DATA[session_id] = {\n            'status': state,\n            'progress': 0\n        }\n\n\ndef save_slice_box(session_id, lung_seg, voxel_coords, radius=32):\n    \"\"\"Draw a red box of side length radius around a nodule on every axial slice it spans and overwrite the saved slice images\"\"\"\n    print(\"Input voxel coordinates:\\t\", voxel_coords)\n    x, y, z = [int(coord) for coord in voxel_coords]\n    half = int(radius / 2)\n    from_z, end_z = z - half, z + half\n    from_x, end_x = x - half, x + half\n    from_y, end_y = y - half, y + half\n    slices_dir = os.path.join(UPLOAD_FOLDER, session_id, 'lung_slices')\n    os.makedirs(slices_dir, exist_ok=True)\n    # Clamp the z range to the volume; negative indices would otherwise wrap around to the far end\n    for z_index in range(max(0, from_z), min(lung_seg.shape[0] - 1, end_z) + 1):\n        slice_data = lung_seg[z_index, :, :]\n        # Scale the 0-1 data to 0-255\n        normalized_slice = (slice_data * 255).astype(np.uint8)\n        # Convert the grayscale slice to BGR so a red bounding box can be drawn\n        color_slice = cv2.cvtColor(normalized_slice, cv2.COLOR_GRAY2BGR)\n
        # Keep the box inside the image\n        from_y_safe = max(0, from_y)\n        end_y_safe = min(normalized_slice.shape[0], end_y)\n        from_x_safe = max(0, from_x)\n        end_x_safe = min(normalized_slice.shape[1], end_x)\n        # Draw the red bounding box (BGR color, line width 2)\n        cv2.rectangle(color_slice, (from_x_safe, from_y_safe), (end_x_safe, end_y_safe), (0, 0, 255), 2)\n        # Save the annotated color image in place of the plain grayscale slice\n        img_filename = f\"z_slice_{z_index:04d}.png\"\n        img_path = os.path.join(slices_dir, img_filename)\n        cv2.imwrite(img_path, color_slice)\n        print(\"Saved nodule bbox slice to\\t\", img_path)\n\n\n# Save every axial slice of the lung segmentation as an image\ndef save_lung_segmentation_slices(session_id, lung_seg):\n    \"\"\"\n    Save all axial (Z) slices of the lung segmentation as PNG images\n\n    Args:\n        session_id: session ID\n        lung_seg: lung segmentation data (3D array, z, y, x)\n    \"\"\"\n    try:\n        # Create the slice image directory\n        slices_dir = os.path.join(UPLOAD_FOLDER, session_id, 'lung_slices')\n        os.makedirs(slices_dir, exist_ok=True)\n        app.logger.info(f\"Saving lung slice images, shape: {lung_seg.shape}\")\n        # Number of slices along each axis\n        z_slices = lung_seg.shape[0]\n        y_slices = lung_seg.shape[1]\n        x_slices = lung_seg.shape[2]\n\n        # Index consumed by the frontend\n        slice_info = {\n            \"dimensions\": lung_seg.shape,\n            \"z_slices\": z_slices,\n            \"y_slices\": y_slices,\n            \"x_slices\": x_slices,\n            \"z_axis\": [],\n            \"y_axis\": [],\n            \"x_axis\": []\n        }\n        # Save the Z-axis (axial) slices\n        app.logger.info(f\"Saving {z_slices} Z-axis slices...\")\n        for z in range(z_slices):\n            # Get the slice\n            slice_data = lung_seg[z, :, :]\n            # Scale the 0-1 data to 0-255\n            normalized_slice = (slice_data * 255).astype(np.uint8)\n            # Output file name and path\n            img_filename = f\"z_slice_{z:04d}.png\"\n            img_path = os.path.join(slices_dir, img_filename)\n            # Save the image\n            cv2.imwrite(img_path, 
normalized_slice)\n            # Add the slice to the index\n            slice_info[\"z_axis\"].append({\n                \"index\": z,\n                \"filename\": img_filename,\n                \"path\": f\"/api/lung_slice/{session_id}/z/{z}\"\n            })\n\n            # Log progress every 10 slices (and after the last one)\n            if z % 10 == 0 or z == z_slices - 1:\n                app.logger.info(f\"Saved Z-axis slice {z + 1}/{z_slices}\")\n        # Write the index file\n        index_path = os.path.join(slices_dir, 'slices_index.json')\n        with open(index_path, 'w') as f:\n            json.dump(slice_info, f)\n\n        # Only Z-axis slices are written; y_axis and x_axis stay empty\n        app.logger.info(f\"Finished saving lung slice images: {z_slices} Z-axis slices\")\n        return True\n\n    except Exception as e:\n        app.logger.error(f\"Error saving lung slice images: {str(e)}\", exc_info=True)\n        return False\n\n\n# API endpoint: serve one axial (Z) lung slice image\n@app.route('/api/lung_slice/<session_id>/z/<int:slice_index>', methods=['GET'])\ndef get_lung_z_slice(session_id, slice_index):\n    \"\"\"Return one Z-axis lung slice image\"\"\"\n    try:\n        app.logger.info(f\"Z-axis slice requested - session ID: {session_id}, slice index: {slice_index}\")\n        # Check that the session exists\n        if session_id not in SESSION_DATA:\n            app.logger.error(f\"Session does not exist: {session_id}\")\n            return jsonify({\"success\": False, \"error\": \"Session does not exist\"}), 404\n\n        # Image file path\n        img_filename = f\"z_slice_{slice_index:04d}.png\"\n        img_path = os.path.join(UPLOAD_FOLDER, session_id, 'lung_slices', img_filename)\n        app.logger.info(f\"Slice image path: {img_path}\")\n\n        # Return the image if it exists, otherwise a 404 (a Flask view must not return None)\n        if os.path.exists(img_path):\n            return send_file(img_path, mimetype='image/png')\n        return jsonify({\"success\": False, \"error\": \"Slice image not found\"}), 404\n\n    except Exception as e:\n        app.logger.error(f\"Error getting Z-axis lung slice: {str(e)}\", exc_info=True)\n        return jsonify({\"success\": False, \"error\": f\"Error getting Z-axis lung slice: {str(e)}\"}), 500\n\n\n# API endpoint: lung slice metadata\n@app.route('/api/lung_slices_info/<session_id>', methods=['GET'])\ndef get_lung_slices_info(session_id):\n
    \"\"\"Return the lung slice metadata\"\"\"\n    print('/api/lung_slices_info session:\\t', session_id)\n    # Check that the session exists\n    if session_id not in SESSION_DATA:\n        return jsonify({\"success\": False, \"error\": \"Session does not exist\"}), 404\n    # Index file path\n    index_path = os.path.join(UPLOAD_FOLDER, session_id, 'lung_slices', 'slices_index.json')\n    print('Index file path:\\t', index_path)\n    # Serve the saved index if it exists\n    if os.path.exists(index_path):\n        with open(index_path, 'r') as f:\n            slices_info = json.load(f)\n        return jsonify({\"success\": True, \"slices_info\": slices_info})\n    print('Index file not found, building slice info from the .npy file')\n    # Otherwise build the basic metadata from the lung segmentation data\n    lung_seg_path = os.path.join(UPLOAD_FOLDER, session_id, 'lung_seg.npy')\n    if not os.path.exists(lung_seg_path):\n        return jsonify({\"success\": False, \"error\": \"Lung segmentation data does not exist\"}), 404\n\n    # Only the shape is needed, so memory-map the array instead of reading it fully\n    lung_seg = np.load(lung_seg_path, mmap_mode='r')\n\n    # Number of slices along each axis\n    z_slices = lung_seg.shape[0]\n    y_slices = lung_seg.shape[1]\n    x_slices = lung_seg.shape[2]\n\n    # Same format as the index written by save_lung_segmentation_slices\n    slices_info = {\n        \"dimensions\": lung_seg.shape,\n        \"z_slices\": z_slices,\n        \"y_slices\": y_slices,\n        \"x_slices\": x_slices,\n        \"z_axis\": [],\n        \"y_axis\": [],\n        \"x_axis\": []\n    }\n    # Add the Z-axis slice entries\n    for z in range(z_slices):\n        slices_info[\"z_axis\"].append({\n            \"index\": z,\n            \"filename\": f\"z_slice_{z:04d}.png\",\n            \"path\": f\"/api/lung_slice/{session_id}/z/{z}\"\n        })\n    return jsonify({\"success\": True, \"slices_info\": slices_info})\n\n\n# Start the server\nif __name__ == '__main__':\n    # Make sure the upload directory exists\n    os.makedirs(UPLOAD_FOLDER, exist_ok=True)\n    # Enable automatic detection after upload\n    app.config['AUTO_DETECT'] = True\n    # debug=True enables the Werkzeug debugger; turn it off when the server listens on 0.0.0.0 outside development\n    app.run(host='0.0.0.0', port=5000, debug=True)\n"
  },
  {
    "path": "deploy/backend/dataclass/CTData.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\nimport numpy as np\nimport SimpleITK as sitk\nfrom scipy import ndimage\nfrom enum import Enum\nimport matplotlib.pyplot as plt\nfrom util.dicom_util import load_dicom_slices, get_pixels_hu, get_dicom_thickness\nfrom util.seg_util import get_segmented_lungs, normalize_hu_values\n\n\nclass CTFormat(Enum):\n    DICOM = 1\n    MHD = 2\n    UNKNOWN = 3\n\n\nclass CTData:\n    \"\"\"\n    Unified CT data class for CT images stored in different formats.\n    Supports loading, processing and analysing DICOM and MHD data.\n    \"\"\"\n    def __init__(self):\n        # Basic attributes\n        self.pixel_data = None  # pixel data, 3D voxel array (z,y,x)\n        self.lung_seg_img = None    # CT image with only the lung region kept\n        self.lung_seg_mask = None   # lung mask\n        self.origin = None  # coordinate origin (x,y,z) in mm\n        self.spacing = None  # voxel spacing (x,y,z) in mm\n        self.orientation = None  # direction matrix\n        self.z_axis_flip = False    # whether the z axis is flipped\n        self.size = None  # image size (z,y,x)\n        self.data_format = None  # data format (DICOM/MHD)\n        self.metadata = {}  # other metadata\n        self.hu_converted = False  # whether values have been converted to HU\n        self.preprocessed = False   # whether the data has been preprocessed\n\n    @classmethod\n    def from_dicom(cls, dicom_path):\n        \"\"\"\n        Load CT data from a DICOM folder\n\n        Args:\n            dicom_path: path to the DICOM folder\n\n        Returns:\n            CTData object\n        \"\"\"\n        ct_data = cls()\n        ct_data.data_format = CTFormat.DICOM\n        slices = load_dicom_slices(dicom_path)\n        ct_data.pixel_data = get_pixels_hu(slices)\n        ct_data.z_axis_flip = slices[1].ImagePositionPatient[2] > slices[0].ImagePositionPatient[2]\n        ct_data.hu_converted = True\n        slice_thickness = get_dicom_thickness(slices)\n        # Voxel spacing; DICOM PixelSpacing is [row spacing (y), column spacing (x)], so reorder it to (x,y,z)\n        try:\n            ct_data.spacing = [\n                float(slices[0].PixelSpacing[1]),\n                float(slices[0].PixelSpacing[0]),\n                float(slice_thickness)\n            ]\n        except Exception:\n            print(\"Warning: could not read the pixel spacing, falling back to [1.0, 1.0, 
1.0]\")\n            ct_data.spacing = [1.0, 1.0, 1.0]\n        # Origin\n        try:\n            ct_data.origin = [\n                float(slices[0].ImagePositionPatient[0]),\n                float(slices[0].ImagePositionPatient[1]),\n                float(slices[0].ImagePositionPatient[2])\n            ]\n        except Exception:\n            print(\"Warning: could not read the coordinate origin, falling back to [0.0, 0.0, 0.0]\")\n            ct_data.origin = [0.0, 0.0, 0.0]\n        # Size\n        ct_data.size = ct_data.pixel_data.shape\n        return ct_data\n\n    @classmethod\n    def from_mhd(cls, mhd_path):\n        \"\"\"\n        Load CT data from an MHD/RAW file pair\n\n        Args:\n            mhd_path: path to the .mhd file\n\n        Returns:\n            CTData object\n        \"\"\"\n        ct_data = cls()\n        ct_data.data_format = CTFormat.MHD\n        try:\n            # Load the MHD file with SimpleITK\n            itk_img = sitk.ReadImage(mhd_path)\n            # Pixel data (note: SimpleITK returns the array in z,y,x order)\n            ct_data.pixel_data = sitk.GetArrayFromImage(itk_img)\n            # LUNA16 MHD data is already in HU\n            ct_data.hu_converted = True\n            # Origin and voxel spacing\n            ct_data.origin = list(itk_img.GetOrigin())  # (x,y,z)\n            ct_data.spacing = list(itk_img.GetSpacing())  # (x,y,z)\n            # Size\n            ct_data.size = ct_data.pixel_data.shape\n            # Direction information\n            ct_data.orientation = itk_img.GetDirection()\n            ct_data.z_axis_flip = False\n        except Exception as e:\n            raise ValueError(f\"Error loading MHD file: {e}\") from e\n\n        return ct_data\n\n    def convert_to_hu(self):\n        \"\"\"\n        Convert pixel values to HU (if not already converted)\n        \"\"\"\n        if self.hu_converted:\n            print(\"Data is already in HU\")\n            return\n\n        if self.data_format == CTFormat.DICOM:\n            # Already handled in from_dicom\n            self.hu_converted = True\n        elif self.data_format == CTFormat.MHD:\n            # LUNA16 MHD data is already in HU\n            self.hu_converted = True\n        else:\n            raise ValueError(\"Unknown data format, cannot convert to HU\")\n\n
    def resample_pixel(self, new_spacing=(1, 1, 1)):\n        \"\"\"\n        Resample the CT voxels to the given spacing.\n        This does not modify self; it returns a new object.\n\n        Args:\n            new_spacing: target voxel spacing [x, y, z] in mm\n\n        Returns:\n            A new, resampled CTData object\n        \"\"\"\n        # Make sure the data is in HU\n        if not self.hu_converted:\n            self.convert_to_hu()\n        # scipy.ndimage works on the array's (z,y,x) axis order, so reorder the spacings to match\n        spacing_zyx = [self.spacing[2], self.spacing[1], self.spacing[0]]\n        new_spacing_zyx = [new_spacing[2], new_spacing[1], new_spacing[0]]\n        # Target shape (rounded to whole voxels)\n        resize_factor = np.array(spacing_zyx) / np.array(new_spacing_zyx)\n        new_shape = np.round(np.array(self.pixel_data.shape) * resize_factor)\n        # Actual zoom factors after rounding\n        real_resize = new_shape / np.array(self.pixel_data.shape)\n        # Resample with trilinear interpolation (order=1)\n        resampled_data = ndimage.zoom(self.pixel_data, real_resize, order=1)\n        # Build the new CTData object and carry over the metadata\n        resampled_ct = CTData()\n        resampled_ct.pixel_data = resampled_data\n        resampled_ct.spacing = list(new_spacing)\n        resampled_ct.origin = self.origin\n        resampled_ct.orientation = self.orientation\n        resampled_ct.z_axis_flip = self.z_axis_flip\n        resampled_ct.size = resampled_data.shape\n        resampled_ct.data_format = self.data_format\n        resampled_ct.hu_converted = self.hu_converted\n        resampled_ct.preprocessed = self.preprocessed\n        return resampled_ct\n\n    def filter_lung_img_mask(self):\n        \"\"\"\n        Keep only the lung-region pixels and normalize them to the 0-1 range.\n        Stores the results in self.lung_seg_img and self.lung_seg_mask.\n        \"\"\"\n        pixel_data = self.pixel_data.copy()\n        seg_img = []\n        seg_mask = []\n        for index in range(pixel_data.shape[0]):\n            one_seg_img, one_seg_mask = get_segmented_lungs(pixel_data[index])\n            one_seg_img = normalize_hu_values(one_seg_img)\n            seg_img.append(one_seg_img)\n            seg_mask.append(one_seg_mask)\n        self.lung_seg_img = np.array(seg_img)\n        self.lung_seg_mask = 
np.array(seg_mask)\n\n    def world_to_voxel(self, world_coord):\n        \"\"\"\n        Convert world coordinates (mm) to voxel coordinates.\n        Assumes an identity direction matrix (self.orientation is not applied).\n\n        Args:\n            world_coord: world coordinates [x,y,z] (mm)\n\n        Returns:\n            voxel coordinates [x,y,z]\n        \"\"\"\n        voxel_coord = np.zeros(3, dtype=int)\n        for i in range(3):\n            voxel_coord[i] = int(round((world_coord[i] - self.origin[i]) / self.spacing[i]))\n\n        return voxel_coord\n\n    def voxel_to_world(self, voxel_coord):\n        \"\"\"\n        Convert voxel coordinates to world coordinates (mm).\n        Assumes an identity direction matrix (self.orientation is not applied).\n\n        Args:\n            voxel_coord: voxel coordinates [x,y,z]\n\n        Returns:\n            world coordinates [x,y,z] (mm)\n        \"\"\"\n        world_coord = np.zeros(3, dtype=float)\n        for i in range(3):\n            world_coord[i] = voxel_coord[i] * self.spacing[i] + self.origin[i]\n\n        return world_coord\n\n    def extract_cube(self, center_world_mm, size_mm, if_fixed_radius=False):\n        \"\"\"\n        Extract a cube around the given center from the lung-only image\n\n        Args:\n            center_world_mm:    world coordinates of the cube center [x,y,z] (mm)\n            size_mm:            cube edge length as a single number; in mm, or in voxels when if_fixed_radius is True\n            if_fixed_radius:    whether to use a fixed cube size. Defaults to False, i.e. the size follows each nodule's annotated diameter, so cubes differ in size\n\n        Returns:\n            cube of pixel data (z,y,x)\n        \"\"\"\n        # Make sure data is loaded\n        if self.pixel_data is None:\n            raise ValueError(\"No data loaded\")\n        if self.lung_seg_img is None:\n            print(\"Lung region not segmented yet, segmenting now...\")\n            self.filter_lung_img_mask()\n        # Convert the world coordinates to voxel coordinates (x,y,z)\n        center_voxel = self.world_to_voxel(center_world_mm)\n        # Reorder to z,y,x to match pixel_data (SimpleITK arrays are z,y,x)\n        center_voxel_zyx = [center_voxel[2], center_voxel[1], center_voxel[0]]\n        # Fixed size: size_mm is already the cube edge length in voxels, so cut it straight from lung_seg_img\n        if if_fixed_radius:\n            half_size = [int(size_mm/2), int(size_mm/2), int(size_mm/2)]\n        else:\n            # Edge length in voxels per axis (z,y,x). LUNA16 annotations give every nodule a different diameter, so cubes vary in size; a fixed size is usually preferable\n            size_voxel = [int(size_mm / 
self.spacing[2]),\n                          int(size_mm / self.spacing[1]),\n                          int(size_mm / self.spacing[0])]\n            # Half edge length per axis\n            half_size = [s // 2 for s in size_voxel]\n        # Cube bounds, clipped to the volume\n        z_min = max(0, center_voxel_zyx[0] - half_size[0])\n        y_min = max(0, center_voxel_zyx[1] - half_size[1])\n        x_min = max(0, center_voxel_zyx[2] - half_size[2])\n\n        z_max = min(self.lung_seg_img.shape[0], center_voxel_zyx[0] + half_size[0])\n        y_max = min(self.lung_seg_img.shape[1], center_voxel_zyx[1] + half_size[1])\n        x_max = min(self.lung_seg_img.shape[2], center_voxel_zyx[2] + half_size[2])\n        # Extract the sub-volume\n        cube = self.lung_seg_img[z_min:z_max, y_min:y_max, x_min:x_max]\n        return cube\n\n    def visualize_slice(self, slice_idx=None, axis=0, show_lung_only=False):\n        \"\"\"\n        Visualize a single slice.\n\n        Args:\n            slice_idx:          slice index; if None, the center slice is used\n            axis:               axis to slice along (0=z, 1=y, 2=x)\n            show_lung_only:     show only the lungs, rendering every other region as black background\n        \"\"\"\n        # Make sure the data is loaded\n        if self.pixel_data is None:\n            raise ValueError(\"No data loaded\")\n        # Segment the lungs on demand, as extract_cube does\n        if show_lung_only and self.lung_seg_img is None:\n            self.filter_lung_img_mask()\n        # Pick the slice index\n        if slice_idx is None:\n            slice_idx = self.pixel_data.shape[axis] // 2\n        # Extract the slice\n        if show_lung_only:\n            if axis == 0:  # z axis\n                slice_data = self.lung_seg_img[slice_idx, :, :]\n            elif axis == 1:  # y axis\n                slice_data = self.lung_seg_img[:, slice_idx, :]\n            else:  # x axis\n                slice_data = self.lung_seg_img[:, :, slice_idx]\n        else:\n            if axis == 0:  # z axis\n                slice_data = self.pixel_data[slice_idx, :, :]\n            elif axis == 1:  # y axis\n                slice_data = self.pixel_data[:, slice_idx, :]\n            else:  # x axis\n                slice_data = self.pixel_data[:, :, slice_idx]\n        # Create the figure\n        plt.figure(figsize=(10, 8))\n        # Show the image only\n    
    plt.imshow(slice_data, cmap='gray')\n        # Set the title\n        axis_name = ['z', 'y', 'x'][axis]\n        title = f\"Slice {slice_idx} (along the {axis_name} axis)\"\n        plt.title(title)\n        plt.colorbar(label='Pixel value')\n        plt.axis('off')\n        plt.tight_layout()\n        plt.show()\n\n    def visualize_nodule(self, coord_x, coord_y, coord_z, diameter):\n        \"\"\"\n        Visualize a nodule as three orthogonal slices through the center of the extracted cube.\n\n        :param coord_x: nodule center x in world coordinates (mm)\n        :param coord_y: nodule center y in world coordinates (mm)\n        :param coord_z: nodule center z in world coordinates (mm)\n        :param diameter: nodule diameter (mm)\n        \"\"\"\n        # Extract the nodule cube, making sure it is large enough\n        cube_size = max(32, int(diameter * 1.5))\n        cube = self.extract_cube([coord_x, coord_y, coord_z], cube_size)\n        # Show three orthogonal planes\n        fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n        # Center slice indices\n        center_z = cube.shape[0] // 2\n        center_y = cube.shape[1] // 2\n        center_x = cube.shape[2] // 2\n        # Draw the three orthogonal planes\n        axes[0].imshow(cube[center_z, :, :], cmap='gray')\n        axes[0].set_title(f'Axial view (z={center_z})')\n        axes[0].axis('off')\n        axes[1].imshow(cube[:, center_y, :], cmap='gray')\n        axes[1].set_title(f'Coronal view (y={center_y})')\n        axes[1].axis('off')\n        axes[2].imshow(cube[:, :, center_x], cmap='gray')\n        axes[2].set_title(f'Sagittal view (x={center_x})')\n        axes[2].axis('off')\n        fig.suptitle(f\"Nodule - position: ({coord_x:.1f}, {coord_y:.1f}, {coord_z:.1f}) mm, \" +\n                     f\"diameter: {diameter:.1f} mm\", fontsize=14)\n        plt.tight_layout()\n        plt.show()\n\n    def save_as_nifti(self, output_path):\n        \"\"\"\n        Save the CT data in NIfTI format.\n\n        Args:\n            output_path: output file path\n        \"\"\"\n        # Make sure the data is loaded\n        if self.pixel_data is None:\n            raise ValueError(\"No data loaded\")\n\n        # Create a SimpleITK image\n        # Note: SimpleITK arrays are ordered z, y, x\n        img = sitk.GetImageFromArray(self.pixel_data)\n        
img.SetOrigin(self.origin)\n        img.SetSpacing(self.spacing)\n\n        if self.orientation is not None:\n            img.SetDirection(self.orientation)\n        # Write to disk; SimpleITK chooses the file format from the extension (e.g. .nii or .nii.gz)\n        sitk.WriteImage(img, output_path)\n        print(f\"Saved as NIfTI: {output_path}\")\n\n"
  },
  {
    "path": "deploy/backend/dataclass/NoduleCube.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\nimport os\nimport torch\nimport numpy as np\nimport cv2\nfrom typing import Optional\nfrom dataclasses import dataclass\nimport matplotlib.pyplot as plt\nfrom scipy import ndimage\n\ndef normal_cube_to_tensor(cube_data):\n    \"\"\"\n    Min-max normalize a cube and convert it to a PyTorch tensor. Used in both training and inference.\n\n    :param cube_data: ndarray of shape [32, 32, 32]\n    :return: float32 tensor of shape (1, 1, 32, 32, 32)\n    \"\"\"\n    cube_data = cube_data.astype(np.float32)\n    # Normalize to the [0, 1] range\n    min_val = np.min(cube_data)\n    max_val = np.max(cube_data)\n    data_range = max_val - min_val\n    # Avoid division by zero for constant cubes\n    if data_range < 1e-10:\n        normalized_cube = np.zeros_like(cube_data)\n    else:\n        normalized_cube = (cube_data - min_val) / data_range\n    # Replace any invalid values (NaN / inf)\n    if np.isnan(normalized_cube).any() or np.isinf(normalized_cube).any():\n        normalized_cube = np.nan_to_num(normalized_cube, nan=0.0, posinf=1.0, neginf=0.0)\n    # Convert to a PyTorch tensor and add batch and channel dimensions\n    cube_tensor = torch.from_numpy(normalized_cube).float().unsqueeze(0).unsqueeze(0)  # (1, 1, 32, 32, 32)\n    return cube_tensor\n\n\n@dataclass\nclass NoduleCube:\n    \"\"\"\n    Lung nodule cube: 3D voxel data of the region around a lung nodule.\n    Independent of the source CT; it only handles cubes that have already been extracted.\n    \"\"\"\n    # Basic attributes\n    cube_size: int = 64  # cube edge length (default 64x64x64)\n    pixel_data: Optional[np.ndarray] = None  # voxel data, shape: [cube_size, cube_size, cube_size]\n    \n    # Nodule features\n    center_x: int = 0  # nodule center x\n    center_y: int = 0  # nodule center y\n    center_z: int = 0  # nodule center z\n    radius: float = 0.0  # nodule radius\n    malignancy: int = 0  # malignancy (0 = benign, 1 = malignant)\n    \n    # File paths\n    npy_path: str = \"\"  # path to the .npy file\n    png_path: str = \"\"  # path to the .png file\n\n    def __post_init__(self):\n        \"\"\"Load pixel_data from npy_path or png_path if no array was passed in.\"\"\"\n        # npy_path given but no pixel_data: load from the NPY file\n        if self.npy_path and self.pixel_data is None:\n            self.load_from_npy()\n        # Otherwise, png_path given but no pixel_data: load from the PNG file\n        elif self.png_path and self.pixel_data is None:\n            self.load_from_png()\n\n    def 
load_from_npy(self) -> None:\n        \"\"\"Load the cube data from an NPY file.\"\"\"\n        if not os.path.exists(self.npy_path):\n            raise FileNotFoundError(f\"File not found: {self.npy_path}\")\n            \n        try:\n            self.pixel_data = np.load(self.npy_path)\n            # Validate the dimensionality\n            if len(self.pixel_data.shape) != 3:\n                raise ValueError(f\"Pixel data must be a 3D array, got shape: {self.pixel_data.shape}\")\n            \n            # Resize if the shape does not match cube_size\n            if (self.pixel_data.shape[0] != self.cube_size or \n                self.pixel_data.shape[1] != self.cube_size or \n                self.pixel_data.shape[2] != self.cube_size):\n                self.resize(self.cube_size)\n                \n        except Exception as e:\n            raise ValueError(f\"Error loading NPY file: {e}\") from e\n\n    def save_to_npy(self, output_path: str) -> str:\n        \"\"\"\n        Save the cube data to an NPY file.\n        \n        Args:\n            output_path: output path\n            \n        Returns:\n            path of the saved file\n        \"\"\"\n        if self.pixel_data is None:\n            raise ValueError(\"No pixel data to save\")\n            \n        # Make sure the parent directory exists (dirname is empty for a bare file name)\n        output_dir = os.path.dirname(output_path)\n        if output_dir:\n            os.makedirs(output_dir, exist_ok=True)\n        np.save(output_path, self.pixel_data)\n        self.npy_path = output_path\n        return output_path\n\n    def save_to_png(self, output_path: str) -> str:\n        \"\"\"\n        Save the cube data as a PNG image with the slices tiled in a grid (8x8 for a 64-slice cube).\n        \n        Args:\n            output_path: output PNG file path\n            \n        Returns:\n            path of the saved file\n        \"\"\"\n        if self.pixel_data is None:\n            raise ValueError(\"No pixel data to save\")\n            \n        # Make sure the parent directory exists (dirname is empty for a bare file name)\n        output_dir = os.path.dirname(output_path)\n        if output_dir:\n            os.makedirs(output_dir, exist_ok=True)\n        \n        # Grid layout of the slices in the final image (8 rows x 8 columns for a 64-slice cube)\n        rows, cols = 8, 8\n        if self.cube_size != 64:\n            # For other sizes, pick a rows x cols factorization that is as close to square as possible\n            total_slices = self.cube_size\n            rows = int(np.sqrt(total_slices))\n            while total_slices % rows != 
0:\n                rows -= 1\n            cols = total_slices // rows\n        \n        # Create the tiled image\n        img_height = self.cube_size\n        img_width = self.cube_size\n        combined_img = np.zeros((rows * img_height, cols * img_width), dtype=np.uint8)\n        \n        # Fill the tiled image slice by slice\n        for i in range(self.cube_size):\n            row = i // cols\n            col = i % cols\n            \n            slice_data = self.pixel_data[i]\n            \n            # Map to 0-255: data in [0, 1] is scaled; values are clipped so the uint8 cast cannot wrap around\n            if slice_data.max() <= 1.0:\n                slice_data = np.clip(slice_data * 255, 0, 255).astype(np.uint8)\n            else:\n                slice_data = np.clip(slice_data, 0, 255).astype(np.uint8)\n            \n            # Place the slice in the tiled image\n            y_start = row * img_height\n            x_start = col * img_width\n            combined_img[y_start:y_start + img_height, x_start:x_start + img_width] = slice_data\n        \n        # Save the tiled image\n        cv2.imwrite(output_path, combined_img)\n        self.png_path = output_path\n        return output_path\n\n    def load_from_png(self) -> None:\n        \"\"\"Load the cube data from a PNG image with the slices tiled in a grid (8x8 for a 64-slice cube).\"\"\"\n        if not os.path.exists(self.png_path):\n            raise FileNotFoundError(f\"File not found: {self.png_path}\")\n            \n        try:\n            # Read the PNG as grayscale; cv2.imread returns None if the file cannot be decoded\n            img = cv2.imread(self.png_path, cv2.IMREAD_GRAYSCALE)\n            if img is None:\n                raise ValueError(f\"Cannot decode image: {self.png_path}\")\n            \n            # Determine the grid layout (must match save_to_png)\n            rows, cols = 8, 8\n            if self.cube_size != 64:\n                # For other sizes, use the same near-square factorization as save_to_png\n                total_slices = self.cube_size\n                rows = int(np.sqrt(total_slices))\n                while total_slices % rows != 0:\n                    rows -= 1\n                cols = total_slices // rows\n            \n            # Check the image dimensions\n            expected_height = rows * self.cube_size\n            expected_width = cols * self.cube_size\n            if img.shape[0] != expected_height or img.shape[1] != expected_width:\n                raise ValueError(f\"Image size mismatch:
 expected {expected_height}x{expected_width}, got {img.shape[0]}x{img.shape[1]}\")\n            \n            # Allocate the 3D array\n            cube_data = np.zeros((self.cube_size, self.cube_size, self.cube_size), dtype=np.float32)\n            \n            # Cut each slice out of the PNG image\n            for i in range(self.cube_size):\n                row = i // cols\n                col = i % cols\n                \n                y_start = row * self.cube_size\n                x_start = col * self.cube_size\n                \n                slice_data = img[y_start:y_start + self.cube_size, x_start:x_start + self.cube_size]\n                cube_data[i] = slice_data.astype(np.float32) / 255.0  # normalize to [0, 1]\n            \n            self.pixel_data = cube_data\n            \n        except Exception as e:\n            raise ValueError(f\"Error loading PNG file: {e}\") from e\n\n    def set_cube_data(self, pixel_data: np.ndarray) -> None:\n        \"\"\"\n        Set the cube's pixel data.\n        \n        Args:\n            pixel_data: 3D pixel data; resized to cube_size if its shape differs\n        \"\"\"\n        if len(pixel_data.shape) != 3:\n            raise ValueError(f\"Pixel data must be a 3D array, got shape: {pixel_data.shape}\")\n        \n        self.pixel_data = pixel_data\n        \n        # Resize if the shape does not match cube_size\n        if (self.pixel_data.shape[0] != self.cube_size or \n            self.pixel_data.shape[1] != self.cube_size or \n            self.pixel_data.shape[2] != self.cube_size):\n            self.resize(self.cube_size)\n\n    def resize(self, new_size: int) -> None:\n        \"\"\"\n        Resize the cube.\n        \n        Args:\n            new_size: new cube edge length\n        \"\"\"\n        if self.pixel_data is None:\n            raise ValueError(\"No pixel data to resize\")\n        \n        # Per-axis zoom factors\n        zoom_factors = [new_size / self.pixel_data.shape[0],\n                         new_size / self.pixel_data.shape[1],\n                         new_size / self.pixel_data.shape[2]]\n        \n        # Resample with scipy.ndimage\n        self.pixel_data = ndimage.zoom(self.pixel_data, zoom_factors, 
mode='nearest')\n        self.cube_size = new_size\n        \n    def augment(self, rotation: bool = True, flip_axis: int = -1, noise: bool = True) -> 'NoduleCube':\n        \"\"\"\n        Data augmentation.\n        \n        Args:\n            rotation: whether to apply a random rotation\n            flip_axis: axis to flip along; the default -1 means no flip\n            noise: whether to add Gaussian noise (values are clipped to [0, 1], so pixel_data is assumed to be normalized)\n            \n        Returns:\n            A new, augmented NoduleCube instance\n        \"\"\"\n        if self.pixel_data is None:\n            raise ValueError(\"No pixel data to augment\")\n        \n        # Work on a copy\n        augmented_cube = self.pixel_data.copy()\n        \n        # Rotation augmentation\n        if rotation:\n            # Random angle (degrees) for a rotation in each of the three axis planes\n            angles = np.random.uniform(-20, 20, 3)\n            augmented_cube = ndimage.rotate(augmented_cube, angles[0], axes=(1, 2), reshape=False, mode='nearest')\n            augmented_cube = ndimage.rotate(augmented_cube, angles[1], axes=(0, 2), reshape=False, mode='nearest')\n            augmented_cube = ndimage.rotate(augmented_cube, angles[2], axes=(0, 1), reshape=False, mode='nearest')\n        \n        # Flip augmentation\n        if flip_axis >= 0:\n            augmented_cube = np.flip(augmented_cube, axis=flip_axis)\n        \n        # Noise augmentation\n        if noise:\n            # Add Gaussian noise with a random standard deviation\n            noise_level = np.random.uniform(0.0, 0.05)\n            noise_array = np.random.normal(0, noise_level, augmented_cube.shape)\n            augmented_cube = augmented_cube + noise_array\n            # Keep values within [0, 1]\n            augmented_cube = np.clip(augmented_cube, 0, 1)\n        \n        # Build the new instance\n        new_cube = NoduleCube(\n            cube_size=self.cube_size,\n            center_x=self.center_x,\n            center_y=self.center_y,\n            center_z=self.center_z,\n            radius=self.radius,\n            malignancy=self.malignancy\n        )\n        \n        new_cube.set_cube_data(augmented_cube)\n        return new_cube\n\n    def visualize_3d(self, output_path: Optional[str] = None, show: bool = True) -> None:\n        
\"\"\"\n        Visualize the cube: three orthogonal center slices (top row) and three maximum intensity projections (bottom row).\n        \n        Args:\n            output_path: optional output path; if given, the figure is saved there\n            show: whether to display the figure\n        \"\"\"\n        if self.pixel_data is None:\n            raise ValueError(\"No pixel data to visualize\")\n            \n        # Create the figure\n        fig, axes = plt.subplots(2, 3, figsize=(12, 8))\n        \n        # Center slice indices\n        center_z = self.pixel_data.shape[0] // 2\n        center_y = self.pixel_data.shape[1] // 2\n        center_x = self.pixel_data.shape[2] // 2\n        \n        # Three orthogonal planes through the center (array order is z, y, x)\n        slice_xy = self.pixel_data[center_z, :, :]\n        slice_xz = self.pixel_data[:, center_y, :]\n        slice_yz = self.pixel_data[:, :, center_x]\n        \n        # Orthogonal views: a fixed-z slice is axial, a fixed-y slice is coronal, a fixed-x slice is sagittal\n        axes[0, 0].imshow(slice_xy, cmap='gray')\n        axes[0, 0].set_title(f'Axial view (Z={center_z})')\n        \n        axes[0, 1].imshow(slice_xz, cmap='gray')\n        axes[0, 1].set_title(f'Coronal view (Y={center_y})')\n        \n        axes[0, 2].imshow(slice_yz, cmap='gray')\n        axes[0, 2].set_title(f'Sagittal view (X={center_x})')\n        \n        # Maximum intensity projections (MIP) along each axis\n        mip_xy = np.max(self.pixel_data, axis=0)\n        mip_xz = np.max(self.pixel_data, axis=1)\n        mip_yz = np.max(self.pixel_data, axis=2)\n        \n        axes[1, 0].imshow(mip_xy, cmap='gray')\n        axes[1, 0].set_title('MIP (axial)')\n        \n        axes[1, 1].imshow(mip_xz, cmap='gray')\n        axes[1, 1].set_title('MIP (coronal)')\n        \n        axes[1, 2].imshow(mip_yz, cmap='gray')\n        axes[1, 2].set_title('MIP (sagittal)')\n        \n        # Nodule information\n        nodule_info = f\"Nodule center: ({self.center_x}, {self.center_y}, {self.center_z})\\n\"\n        nodule_info += f\"Radius: {self.radius:.1f}\\n\"\n        nodule_info += f\"Malignancy: {'malignant' if self.malignancy == 1 else 'benign'}\"\n        \n        fig.suptitle(nodule_info, fontsize=12)\n        plt.tight_layout()\n        \n        if output_path:\n            plt.savefig(output_path, dpi=200, bbox_inches='tight')\n     
   \n        if show:\n            plt.show()\n        else:\n            plt.close(fig)\n            \n    @classmethod\n    def from_npy(cls, file_path: str, cube_size: int = 64) -> 'NoduleCube':\n        \"\"\"\n        Create a cube instance from an NPY file.\n        \n        Args:\n            file_path: NPY file path\n            cube_size: cube edge length\n            \n        Returns:\n            NoduleCube instance\n        \"\"\"\n        # __post_init__ already loads the data from npy_path, so load_from_npy() is not called again\n        cube = cls(cube_size=cube_size, npy_path=file_path)\n        return cube\n    \n    @classmethod\n    def from_png(cls, file_path: str, cube_size: int = 64) -> 'NoduleCube':\n        \"\"\"\n        Create a cube instance from a PNG file.\n        \n        Args:\n            file_path: PNG file path\n            cube_size: cube edge length\n            \n        Returns:\n            NoduleCube instance\n        \"\"\"\n        # __post_init__ already loads the data from png_path, so load_from_png() is not called again\n        cube = cls(cube_size=cube_size, png_path=file_path)\n        return cube\n        \n    @classmethod\n    def from_array(cls, \n                  pixel_data: np.ndarray, \n                  center_x: int = 0, \n                  center_y: int = 0, \n                  center_z: int = 0,\n                  radius: float = 0.0,\n                  malignancy: int = 0) -> 'NoduleCube':\n        \"\"\"\n        Create a cube instance from a NumPy array.\n        \n        Args:\n            pixel_data: 3D pixel data; all three dimensions must be equal\n            center_x: center X coordinate\n            center_y: center Y coordinate\n            center_z: center Z coordinate\n            radius: nodule radius\n            malignancy: malignancy (0 = benign, 1 = malignant)\n            \n        Returns:\n            NoduleCube instance\n        \"\"\"\n        if len(pixel_data.shape) != 3:\n            raise ValueError(f\"Pixel data must be a 3D array, got shape: {pixel_data.shape}\")\n            \n        cube_size = pixel_data.shape[0]\n        if pixel_data.shape[1] != cube_size or pixel_data.shape[2] != cube_size:\n            raise ValueError(f\"Pixel data must be cube-shaped, got shape: {pixel_data.shape}\")\n            \n        cube = cls(\n            cube_size=cube_size,\n            center_x=center_x,\n           
 center_y=center_y,\n            center_z=center_z,\n            radius=radius,\n            malignancy=malignancy\n        )\n        \n        cube.set_cube_data(pixel_data)\n        return cube\n\n\n"
  },
  {
    "path": "deploy/backend/dataclass/__init__.py",
    "content": ""
  },
  {
    "path": "deploy/backend/detector.py",
    "content": "import os\nimport sys\nimport time\nimport logging\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom threading import Thread, Lock\nimport json\n\n# Add the project root directory to sys.path\nsys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))\n\n# Import the lung nodule detection module\nfrom inference.pytorch_nodule_detector import (\n    load_ct_data, \n    load_model, \n    get_lung_bounds, \n    scan_ct_data, \n    reduce_overlapping_nodules, \n    filter_false_positives, \n    format_results\n)\n\n# Configure logging\nlogging.basicConfig(level=logging.INFO)\nlogger = logging.getLogger(__name__)\n\n# Per-session detection state and locks, keyed by session ID\nsession_states = {}\nsession_locks = {}\n\n# Detection status constants\nSTATUS_LOADING = \"loading\"       # loading data\nSTATUS_PREPROCESSING = \"preprocessing\"  # preprocessing\nSTATUS_SCANNING = \"scanning\"     # scanning for nodules\nSTATUS_FILTERING = \"filtering\"   # filtering false positives\nSTATUS_COMPLETED = \"completed\"   # completed\nSTATUS_ERROR = \"error\"           # error\n\nclass NoduleDetector:\n    \"\"\"Lung nodule detector that wraps the functions in pytorch_nodule_detector.py.\"\"\"\n    \n    def __init__(self, model_path=None, device='cuda'):\n        \"\"\"Initialize the detector.\n        \n        Args:\n            model_path: model path\n            device: device to use ('cuda' or 'cpu'); falls back to CPU when CUDA is unavailable\n        \"\"\"\n        self.model_path = model_path\n        self.device = torch.device(device if torch.cuda.is_available() and device == 'cuda' else 'cpu')\n        self.model = None\n        # Callback invoked when detection completes (see set_completion_callback)\n        self.completion_callback = None\n        print(\"Model path given:\\t\", model_path)\n        # Load the model if a valid path was provided; report success only after loading actually succeeded\n        if model_path and os.path.exists(model_path):\n            if self.load_model(model_path):\n                print(\"Model loaded successfully!\")\n            \n    def load_model(self, model_path):\n        \"\"\"Load the model.\n        \n        Args:\n            model_path: model path\n            \n        Returns:\n            True if loading succeeded, otherwise False\n        \"\"\"\n        try:\n            if not os.path.exists(model_path):\n                logger.error(f\"Model file not found: {model_path}\")\n                return False\n         
       \n            self.model, self.device = load_model(model_path, self.device)\n            self.model_path = model_path\n            return True\n        except Exception as e:\n            logger.error(f\"Error while loading the model: {e}\")\n            return False\n            \n    def detect(self, file_path, session_id, patient_id=None):\n        \"\"\"Start lung nodule detection in a background thread.\n        \n        Args:\n            file_path: path to a CT file or folder\n            session_id: session ID\n            patient_id: patient ID\n            \n        Returns:\n            True if detection was started, otherwise False\n        \"\"\"\n        if self.model is None:\n            logger.error(\"Model not loaded\")\n            return False\n            \n        if session_id in session_states:\n            logger.warning(f\"Session {session_id} already exists and will be overwritten\")\n        print(\"file_path received by detect:\\t\", file_path)\n        # Create the lock before publishing the session state, so readers never find a state without a lock\n        if session_id not in session_locks:\n            session_locks[session_id] = Lock()\n            \n        # Initialize the session state\n        session_states[session_id] = {\n            \"status\": STATUS_LOADING,\n            \"progress\": 0,\n            \"message\": \"Loading CT data...\",\n            \"started_at\": time.time(),\n            \"patient_id\": patient_id,\n            \"file_path\": file_path,\n            \"ct_data\": None,\n            \"nodules\": None,\n            \"lung_bounds\": None,\n            \"error\": None\n        }\n        \n        # Start the detection thread\n        thread = Thread(target=self._detect_thread, args=(file_path, session_id, patient_id))\n        thread.daemon = True\n        thread.start()\n        \n        return True\n        \n    def _detect_thread(self, file_path, session_id, patient_id):\n        \"\"\"Detection worker that runs in a background thread.\n        \n        Args:\n            file_path: path to a CT file or folder\n            session_id: session ID\n            patient_id: patient ID\n        \"\"\"\n        try:\n            # Load the CT data\n            self._update_session(session_id, {\n                \"status\": STATUS_LOADING,\n                \"progress\": 0,\n                
\"message\": \"Loading CT data...\"\n            })\n            \n            ct_data = load_ct_data(file_path)\n            \n            # Check that the CT data was loaded\n            if ct_data is None:\n                raise ValueError(\"Failed to load CT data\")\n                \n            # Update the session state\n            self._update_session(session_id, {\n                \"ct_data\": ct_data,\n                \"progress\": 10,\n                \"message\": \"CT data loaded, starting lung segmentation...\"\n            })\n            \n            # Preprocessing stage\n            self._update_session(session_id, {\n                \"status\": STATUS_PREPROCESSING,\n                \"progress\": 20,\n                \"message\": \"Segmenting the lungs...\"\n            })\n            \n            # Lung segmentation already happens inside load_ct_data; here only the lung bounds are derived\n            lung_bounds = get_lung_bounds(ct_data.lung_seg_mask)\n            if lung_bounds is None:\n                raise ValueError(\"No valid lung region found\")\n                \n            # Update the session state\n            self._update_session(session_id, {\n                \"lung_bounds\": lung_bounds,\n                \"progress\": 30,\n                \"message\": \"Lung segmentation complete, starting nodule detection...\"\n            })\n            \n            # Start scanning\n            self._update_session(session_id, {\n                \"status\": STATUS_SCANNING,\n                \"progress\": 40,\n                \"message\": \"Scanning the lung region for nodules...\"\n            })\n            # Logger-like object that turns scan progress messages into session updates\n            progress_logger = self._create_progress_logger(session_id)\n            # Run the scan\n            results_df = scan_ct_data(ct_data, self.model, self.device, progress_logger)\n            # Update the session state\n            self._update_session(session_id, {\n                \"progress\": 80,\n                \"message\": f\"Scan complete: {len(results_df)} candidate nodules found, filtering false positives...\"\n            })\n            \n            # Merge overlapping nodules\n            self._update_session(session_id, {\n                \"status\": STATUS_FILTERING,\n                \"progress\": 85,\n                \"message\": 
\"Merging overlapping nodules...\"\n            })\n            reduced_df = reduce_overlapping_nodules(results_df)\n            \n            # Filter false positives\n            self._update_session(session_id, {\n                \"progress\": 90,\n                \"message\": f\"{len(reduced_df)} nodules remain after merging, filtering false positives...\"\n            })\n            filtered_df = filter_false_positives(reduced_df, ct_data)\n            \n            # Format the results\n            self._update_session(session_id, {\n                \"progress\": 95,\n                \"message\": f\"Filtering complete: {len(filtered_df)} nodules detected, generating results...\"\n            })\n            final_df = format_results(filtered_df, ct_data, patient_id or session_id)\n            \n            # Extract the nodule cubes\n            nodules = self._extract_nodule_cubes(ct_data, final_df)\n            \n            # Final session state\n            final_results = {\n                \"status\": STATUS_COMPLETED,\n                \"progress\": 100,\n                \"message\": f\"Detection complete: {len(nodules)} nodules found\",\n                \"nodules\": nodules,\n                \"completed_at\": time.time()\n            }\n            \n            self._update_session(session_id, final_results)\n            \n            # Invoke the completion callback, if one is set\n            if self.completion_callback:\n                try:\n                    self.completion_callback(session_id, {\"nodules\": nodules})\n                except Exception as callback_error:\n                    logger.error(f\"Error in completion callback: {str(callback_error)}\", exc_info=True)\n            \n        except Exception as e:\n            logger.error(f\"Error during detection: {str(e)}\", exc_info=True)\n            self._update_session(session_id, {\n                \"status\": STATUS_ERROR,\n                \"message\": f\"Detection failed: {str(e)}\",\n                \"error\": str(e)\n            })\n    \n    def _create_progress_logger(self, session_id):\n        \"\"\"Create a logger-like object that receives progress updates.\n        \n        Args:\n            session_id: session ID\n            \n        Returns:\n            
a custom logger-like object with an info() method\n        \"\"\"\n        class ProgressLogger:\n            def __init__(self, outer_self, session_id):\n                self.outer_self = outer_self\n                self.session_id = session_id\n                self.last_progress_time = time.time()\n                \n            def info(self, message):\n                # Parse progress messages. The Chinese markers below (处理进度 = processing progress, 扫描完成 = scan complete)\n                # are matched against the messages logged by scan_ct_data, so they are kept verbatim\n                if \"处理进度:\" in message:\n                    try:\n                        # Extract current/total from the text between the marker and the percent sign\n                        progress_part = message.split(\"处理进度:\")[1].split(\"%\")[0].strip()\n                        progress_parts = progress_part.split('/')\n                        if len(progress_parts) == 2:\n                            current, total = map(int, progress_parts)\n                            progress = min(40 + int(current / total * 40), 80)  # scanning maps to 40%-80% of overall progress\n                            \n                            # Update the session at most once per second\n                            current_time = time.time()\n                            if current_time - self.last_progress_time > 1:\n                                self.last_progress_time = current_time\n                                self.outer_self._update_session(self.session_id, {\n                                    \"progress\": progress,\n                                    \"message\": message\n                                })\n                    except (ValueError, IndexError, ZeroDivisionError):\n                        # Ignore progress messages that do not have the expected format\n                        pass\n                \n                # Forward the scan-complete message\n                if \"扫描完成\" in message:\n                    self.outer_self._update_session(self.session_id, {\n                        \"progress\": 80,\n                        \"message\": message\n                    })\n        \n        return ProgressLogger(self, session_id)\n    \n    def _extract_nodule_cubes(self, ct_data, nodules_df, cube_size=32):\n        \"\"\"Extract nodule cubes from the CT data.\n        \n        Args:\n            ct_data: CTData object\n            nodules_df: nodule DataFrame\n            cube_size: cube edge length in voxels\n            \n        
Returns:\n            list of nodules; each entry holds the cube data and related information\n        \"\"\"\n        nodule_list = []\n        \n        for _, row in nodules_df.iterrows():\n            try:\n                # Voxel coordinates of the nodule center\n                x = int(row['voxel_x'])\n                y = int(row['voxel_y'])\n                z = int(row['voxel_z'])\n                \n                # Cube bounds, clipped to the volume\n                half_size = cube_size // 2\n                z_min = max(0, z - half_size)\n                y_min = max(0, y - half_size)\n                x_min = max(0, x - half_size)\n                z_max = min(ct_data.lung_seg_img.shape[0], z + half_size)\n                y_max = min(ct_data.lung_seg_img.shape[1], y + half_size)\n                x_max = min(ct_data.lung_seg_img.shape[2], x + half_size)\n                \n                # Extract the cube\n                cube = ct_data.lung_seg_img[z_min:z_max, y_min:y_max, x_min:x_max]\n                \n                # Near the volume border the crop is clipped; zero-pad it back to full size\n                if cube.shape != (cube_size, cube_size, cube_size):\n                    # Offset the crop by the amount clipped at the lower border so the nodule stays centered\n                    pad_z = z_min - (z - half_size)\n                    pad_y = y_min - (y - half_size)\n                    pad_x = x_min - (x - half_size)\n                    padded_cube = np.zeros((cube_size, cube_size, cube_size), dtype=cube.dtype)\n                    padded_cube[pad_z:pad_z + cube.shape[0], \n                                pad_y:pad_y + cube.shape[1], \n                                pad_x:pad_x + cube.shape[2]] = cube\n                    cube = padded_cube\n                \n                # Collect the nodule information\n                nodule_info = {\n                    'id': int(row['nodule_id']),\n                    'cube': cube.tolist(),  # convert to a list for JSON serialization\n                    'voxel_coords': [x, y, z],\n                    'world_coords': [float(row['world_x']), float(row['world_y']), float(row['world_z'])],\n                    
'diameter_mm': float(row['diameter_mm']),\n                    'probability': float(row['prob'])\n                }\n                \n                nodule_list.append(nodule_info)\n                \n            except Exception as e:\n                logger.error(f\"Error extracting nodule cube: {str(e)}\", exc_info=True)\n        \n        return nodule_list\n    \n    def _update_session(self, session_id, updates):\n        \"\"\"Update the session state.\n        \n        Args:\n            session_id: session ID\n            updates: fields to update\n        \"\"\"\n        if session_id not in session_states:\n            logger.warning(f\"Session {session_id} does not exist\")\n            return\n            \n        # Acquire the session lock\n        with session_locks[session_id]:\n            for key, value in updates.items():\n                session_states[session_id][key] = value\n                \n    def get_session_state(self, session_id):\n        \"\"\"Get the session state.\n        \n        Args:\n            session_id: session ID\n            \n        Returns:\n            a copy of the session state with ct_data removed and lung_bounds summarized, or None if the session does not exist\n        \"\"\"\n        if session_id not in session_states:\n            return None\n            \n        # Acquire the session lock\n        with session_locks[session_id]:\n            # Shallow copy so the stored state is not modified\n            state = session_states[session_id].copy()\n            \n            # Drop or summarize fields that cannot be serialized\n            if 'ct_data' in state:\n                del state['ct_data']\n            if 'lung_bounds' in state:\n                bounds_info = {}\n                if state['lung_bounds']:\n                    bounds_info = {\n                        'z_min': state['lung_bounds']['z_min'],\n                        'z_max': state['lung_bounds']['z_max'],\n                        'region_count': len(state['lung_bounds']['regions'])\n                    }\n                state['lung_bounds'] = bounds_info\n            \n            return state\n            \n    def set_completion_callback(self, callback):\n        \"\"\"Set the callback invoked when detection completes.\n        \n        Args:\n            callback: 
Callable taking session_id and results\n        \"\"\"\n        self.completion_callback = callback\n            \ndef get_detector_instance(model_path=None):\n    \"\"\"Return the shared detector instance (singleton)\n    \n    Args:\n        model_path: Path to the model weights\n        \n    Returns:\n        The NoduleDetector instance\n    \"\"\"\n    if not hasattr(get_detector_instance, 'instance'):\n        get_detector_instance.instance = NoduleDetector(model_path)\n    elif model_path and get_detector_instance.instance.model_path != model_path:\n        # A different model path was given: reload the model\n        get_detector_instance.instance.load_model(model_path)\n        \n    return get_detector_instance.instance "
  },
  {
    "path": "deploy/backend/models/pytorch_c3d_tiny.py",
    "content": "import torch.nn as nn\n\n\nclass C3dTiny(nn.Module):\n    def __init__(self):\n        super().__init__()\n        # Conv block 1: the (1,2,2) pool keeps the first spatial axis and halves the other two\n        self.conv_block1 = nn.Sequential(\n            nn.Conv3d(in_channels=1, kernel_size=3, padding = 1, out_channels=64),\n            # BatchNorm is not part of the original C3D; added here\n            nn.BatchNorm3d(64),\n            nn.ReLU(),\n            nn.MaxPool3d(kernel_size=(1,2,2), stride = (1,2,2))\n        )\n        # Conv block 2\n        self.conv_block2 = nn.Sequential(\n            nn.Conv3d(in_channels=64, kernel_size=3, padding = 1, out_channels=128),\n            # BatchNorm is not part of the original C3D; added here\n            nn.BatchNorm3d(128),\n            nn.ReLU(),\n            nn.MaxPool3d(kernel_size=2)\n        )\n        self.drop_out1 = nn.Dropout(0.2)\n        # Conv block 3: two 3x3x3 convolutions\n        self.conv_block3 = nn.Sequential(\n            nn.Conv3d(in_channels = 128, kernel_size=3, padding = 1, out_channels=256),\n            nn.BatchNorm3d(256),\n            nn.ReLU(),\n            nn.Conv3d(in_channels=256, kernel_size=3, padding = 1, out_channels=256),\n            nn.BatchNorm3d(256),\n            nn.ReLU(),\n            nn.MaxPool3d(kernel_size=2)\n        )\n        self.drop_out2 = nn.Dropout(0.2)\n        # Conv block 4: two 3x3x3 convolutions\n        self.conv_block4 = nn.Sequential(\n            nn.Conv3d(in_channels = 256, kernel_size = 3, padding = 1, out_channels=512),\n            nn.BatchNorm3d(512),\n            nn.ReLU(),\n            nn.Conv3d(in_channels = 512, kernel_size = 3, padding = 1, out_channels = 512),\n            nn.BatchNorm3d(512),\n            nn.ReLU(),\n            nn.MaxPool3d(kernel_size=2)\n        )\n        self.drop_out3 = nn.Dropout(0.2)\n        self.flatten = nn.Flatten()\n        # Number of input features for fc1:\n        # the 32x32x32 input becomes 32x16x16 after pool1 (1,2,2)\n        # 16x8x8 after pool2 (2,2,2)\n 
       # 8x4x4 after pool3 (2,2,2)\n        # 4x2x2 after pool4 (2,2,2)\n        # so the final feature map is 4x2x2 with 512 channels\n        self.fc1 = nn.Sequential(\n            nn.Linear(512 * 4 * 2 * 2, 512),\n            nn.ReLU()\n        )\n        self.fc2 = nn.Linear(512, 2)\n\n    def forward(self, x):\n        # x: (N, 1, 32, 32, 32) input cubes; returns (N, 2) class logits\n        x = self.conv_block1(x)\n        x = self.conv_block2(x)\n        x = self.drop_out1(x)\n        x = self.conv_block3(x)\n        x = self.drop_out2(x)\n        x = self.conv_block4(x)\n        x = self.drop_out3(x)\n        x = self.flatten(x)\n        x = self.fc1(x)\n        x = self.fc2(x)\n        return x"
  },
  {
    "path": "deploy/backend/models/pytorch_nodule_detector.py",
    "content": "import os\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom torch.nn import functional as F\nfrom datetime import datetime\nimport logging\nimport time\nfrom scipy import ndimage\nfrom data.dataclass.CTData import CTData\nfrom data.dataclass.NoduleCube import normal_cube_to_tensor\nfrom deploy.backend.preprocessing.luna16_invalid_nodule_filter import nodule_valid\nfrom models.pytorch_c3d_tiny import C3dTiny\n\n# Inference parameters\nCUBE_SIZE = 32  # scan cube size: 32x32x32 voxels\nSCAN_STEP = 10  # sliding-window stride in y and x, in voxels (1 voxel = 1 mm after resampling)\nPROB_THRESHOLD = 0.8  # a window counts as a nodule only if its probability exceeds this value\n\n# Logging setup\ndef setup_logger(log_dir=\"./inference_logs\"):\n    \"\"\"Configure a logger that writes to a timestamped file and to the console\"\"\"\n    os.makedirs(log_dir, exist_ok=True)\n    log_file = os.path.join(log_dir, f\"inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log\")\n    \n    # Create (or fetch) the logger\n    logger = logging.getLogger('nodule_detection')\n    logger.setLevel(logging.INFO)\n    \n    # Remove handlers left over from earlier calls so messages are not logged twice\n    for old_handler in list(logger.handlers):\n        logger.removeHandler(old_handler)\n        old_handler.close()\n    \n    # File handler\n    file_handler = logging.FileHandler(log_file)\n    file_handler.setLevel(logging.INFO)\n    \n    # Console handler\n    console_handler = logging.StreamHandler()\n    console_handler.setLevel(logging.INFO)\n    \n    # Formatter\n    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n    file_handler.setFormatter(formatter)\n    console_handler.setFormatter(formatter)\n    \n    # Attach the handlers\n    logger.addHandler(file_handler)\n    logger.addHandler(console_handler)\n    \n    return logger\n\ndef load_ct_data(file_path):\n    \"\"\"\n        Load CT data (MHD or DICOM) and preprocess it\n    Args:\n        file_path: path to a CT file or a DICOM folder\n    Returns:\n        CTData object\n    \"\"\"\n    # File or directory?\n    if os.path.isfile(file_path):\n        # A single file must be an MHD file\n        if file_path.endswith('.mhd'):\n            ct_data = CTData.from_mhd(file_path)\n        else:\n            raise ValueError(f\"Unsupported file type: {file_path}\")\n    elif os.path.isdir(file_path):\n        # A directory is treated as a DICOM series\n        ct_data = CTData.from_dicom(file_path)\n    else:\n        raise ValueError(f\"Path does not exist: 
{file_path}\")\n    # Resample to 1 mm isotropic spacing\n    ct_data = ct_data.resample_pixel(new_spacing=[1, 1, 1])\n    # Segment the lung region\n    ct_data.filter_lung_img_mask()\n    return ct_data\n\ndef load_model(eval_model_path, device='cuda'):\n    \"\"\"\n    Load the PyTorch model\n    \n    Args:\n        eval_model_path: path to the model weights (state_dict) file\n        device: compute device ('cuda' or 'cpu')\n\n    Returns:\n        (model, device): the model with weights loaded, in eval mode, and the device\n    \"\"\"\n    model = C3dTiny().to(device)\n    # Load the weights onto the target device\n    model.load_state_dict(torch.load(eval_model_path, map_location=device))\n    model.eval()\n    return model, device\n\ndef get_lung_bounds(lung_mask):\n    \"\"\"Compute per-slice bounding boxes of the lung mask, allowing for separated left and right lungs\"\"\"\n    if lung_mask.sum() == 0:\n        return None\n    # Connected-component analysis to find the lung regions\n    labeled_mask, num_features = ndimage.label(lung_mask > 0)\n    # With more than two components, keep only the two largest (normally the left and right lungs)\n    if num_features > 2:\n        # Voxel count per label; index 0 is the background, so drop it\n        region_sizes = np.bincount(labeled_mask.ravel())[1:]\n        # Labels of the two largest regions\n        valid_labels = np.argsort(region_sizes)[-2:] + 1\n        # Mask containing only those regions\n        refined_mask = np.zeros_like(labeled_mask)\n        for label in valid_labels:\n            refined_mask[labeled_mask == label] = 1\n    else:\n        refined_mask = lung_mask > 0\n    \n    # Bounding box of the lung on each z slice\n    z_ranges = []\n    margin = 5  # margin added around each slice's box, in voxels\n    \n    # Iterate over the z slices\n    for z in range(refined_mask.shape[0]):\n        slice_mask = refined_mask[z]\n        if slice_mask.sum() > 100:  # only slices with enough lung voxels\n            y_indices, x_indices = np.where(slice_mask)\n            if len(y_indices) > 0:\n                y_min = max(0, y_indices.min() - margin)\n                y_max = min(refined_mask.shape[1], y_indices.max() + margin)\n                x_min = max(0, x_indices.min() - margin)\n                x_max = min(refined_mask.shape[2], x_indices.max() + margin)\n                z_ranges.append((z, y_min, y_max, x_min, x_max))\n    \n    if not z_ranges:\n        return None\n    \n    # Overall z range\n    z_min = 
z_ranges[0][0]\n    z_max = z_ranges[-1][0] + 1\n    \n    # Per-slice y/x ranges\n    scan_regions = []\n    for z_slice, y_min, y_max, x_min, x_max in z_ranges:\n        scan_regions.append({\n            'z': z_slice,\n            'y_min': y_min,\n            'y_max': y_max,\n            'x_min': x_min,\n            'x_max': x_max\n        })\n    \n    return {\n        'z_min': z_min,\n        'z_max': z_max,\n        'regions': scan_regions\n    }\n\ndef scan_ct_data(ct_data, model, device, logger, step=SCAN_STEP):\n    \"\"\"\n    Scan the whole CT volume with a sliding window and predict nodule positions (optimized version)\n    \n    Args:\n        ct_data: CTData object\n        model: PyTorch model\n        device: compute device\n        logger: logger\n        step: window stride in y and x, in voxels\n        \n    Returns:\n        DataFrame with one row per window whose probability exceeds PROB_THRESHOLD\n    \"\"\"\n    logger.info(\"Scanning CT data...\")\n    \n    # Lung-segmented image and mask\n    lung_img = ct_data.lung_seg_img\n    lung_mask = ct_data.lung_seg_mask\n    \n    # Lung bounds\n    bounds = get_lung_bounds(lung_mask)\n    if bounds is None:\n        logger.warning(\"No valid lung region found\")\n        return pd.DataFrame(columns=['voxel_coord_x', 'voxel_coord_y', 'voxel_coord_z', \n                                    'world_coord_x', 'world_coord_y', 'world_coord_z', 'prob'])\n    \n    logger.info(f\"Lung region: z from {bounds['z_min']} to {bounds['z_max']}, {len(bounds['regions'])} slices\")\n    \n    # Collected detections\n    results = []\n    \n    # Rough estimate of the number of windows to scan (used only for progress logging)\n    total_voxels = 0\n    for region in bounds['regions']:\n        y_range = region['y_max'] - region['y_min']\n        x_range = region['x_max'] - region['x_min']\n        total_voxels += (y_range // step + 1) * (x_range // step + 1)\n    \n    logger.info(f\"Estimated number of windows: {total_voxels}\")\n    \n    # Start timing\n    start_time = time.time()\n    batch_size = 32  # larger batches improve GPU utilization\n    batch_inputs = []\n    batch_positions = []\n    # Progress counters\n    processed_voxels = 0\n    skipped_voxels = 0\n    # Lung-tissue threshold\n    lung_tissue_threshold = 0.1  # minimum fraction of lung voxels in a window\n    # Scan the lung region slice by slice (z advances one slice at a time)\n    for z_idx, region in 
enumerate(bounds['regions']):\n        z = region['z']\n        # Skip slices where a full cube would run past the last slice\n        if z + CUBE_SIZE > lung_img.shape[0]:\n            continue\n        # Slide the window over this slice's bounding box\n        for y in range(region['y_min'], region['y_max'] - CUBE_SIZE + 1, step):\n            for x in range(region['x_min'], region['x_max'] - CUBE_SIZE + 1, step):\n                # Lung-mask cube at this position\n                mask_cube = lung_mask[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE]\n                # Fraction of lung voxels\n                lung_ratio = np.mean(mask_cube)\n                # Skip windows with too little lung tissue\n                if lung_ratio < lung_tissue_threshold:\n                    skipped_voxels += 1\n                    continue\n                # Image cube at this position\n                cube = lung_img[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE]\n                # Normalize and convert to a tensor; unsqueeze(0) adds a leading dimension so cubes can be concatenated into a batch\n                cube_tensor = normal_cube_to_tensor(cube)\n                cube_tensor = cube_tensor.unsqueeze(0)\n                # Queue for batched inference\n                batch_inputs.append(cube_tensor)\n                batch_positions.append((z, y, x))\n                # Run the model once the batch is full\n                if len(batch_inputs) == batch_size:\n                    process_batch(batch_inputs, batch_positions, model, device, ct_data, results)\n                    batch_inputs = []\n                    batch_positions = []\n                processed_voxels += 1\n                # Periodic progress report\n                if (processed_voxels + skipped_voxels) % 1000 == 0:\n                    elapsed_time = time.time() - start_time\n                    progress = processed_voxels / total_voxels * 100 if total_voxels > 0 else 0\n                    logger.info(f\"Progress: {processed_voxels}/{total_voxels} ({progress:.2f}%), \"\n                                f\"skipped: {skipped_voxels}, elapsed: {elapsed_time:.2f}s\")\n    \n    # Flush the last, partial batch\n    if batch_inputs:\n        process_batch(batch_inputs, batch_positions, model, device, ct_data, results)\n    \n    # 
Build the DataFrame\n    if results:\n        results_df = pd.DataFrame(results)\n        logger.info(f\"Scan complete: found {len(results_df)} candidate nodules\")\n    else:\n        results_df = pd.DataFrame(columns=['voxel_coord_x', 'voxel_coord_y', 'voxel_coord_z', \n                                          'world_coord_x', 'world_coord_y', 'world_coord_z', 'prob'])\n        logger.info(\"Scan complete: no nodules found\")\n    \n    return results_df\n\ndef process_batch(batch_inputs, batch_positions, model, device, ct_data, results):\n    \"\"\"Run the model on one batch and append the windows above PROB_THRESHOLD to results\"\"\"\n    # Stack the batch\n    batch_tensor = torch.cat(batch_inputs, dim=0).to(device)\n    # Inference\n    with torch.no_grad():\n        batch_outputs = model(batch_tensor)\n        # Probability of class 1 (nodule), copied to the CPU in a single transfer\n        batch_probs = F.softmax(batch_outputs, dim=1)[:, 1].cpu().tolist()\n    # Keep the windows above the threshold\n    for (z_pos, y_pos, x_pos), prob_value in zip(batch_positions, batch_probs):\n        if prob_value > PROB_THRESHOLD:\n            # Window center in voxel coordinates\n            center_z = z_pos + CUBE_SIZE // 2\n            center_y = y_pos + CUBE_SIZE // 2\n            center_x = x_pos + CUBE_SIZE // 2\n            # Voxel to world coordinates (mm)\n            world_coord = ct_data.voxel_to_world([center_x, center_y, center_z])\n            # Record the detection\n            results.append({\n                'voxel_coord_x': center_x,\n                'voxel_coord_y': center_y,\n                'voxel_coord_z': center_z,\n                'world_coord_x': world_coord[0],\n                'world_coord_y': world_coord[1],\n                'world_coord_z': world_coord[2],\n                'prob': prob_value\n            })\n\ndef reduce_overlapping_nodules(results_df, distance_threshold=15):\n    \"\"\"\n    Merge overlapping nodule predictions: keep the highest-probability detection and drop any\n    lower-probability detection whose center lies closer than distance_threshold to it\n    \n    Args:\n        results_df: DataFrame of nodule predictions\n        distance_threshold: merge distance threshold, in voxels\n        \n    Returns:\n        DataFrame of the merged nodules\n    \"\"\"\n    if len(results_df) <= 1:\n        return results_df\n    # Sort by probability, highest first\n    sorted_df = results_df.sort_values('prob', 
ascending=False).reset_index(drop=True)\n    # Voxel coordinates of all detections, as an (N, 3) array\n    coords = sorted_df[['voxel_coord_x', 'voxel_coord_y', 'voxel_coord_z']].to_numpy(dtype=float)\n    # Boolean mask of the rows to keep\n    keep_mask = np.ones(len(sorted_df), dtype=bool)\n    for i in range(len(sorted_df)):\n        if not keep_mask[i]:\n            continue  # already suppressed by a higher-probability detection\n        # 3D Euclidean distance from detection i to every lower-probability detection\n        distances = np.linalg.norm(coords[i + 1:] - coords[i], axis=1)\n        # Suppress those closer than the threshold\n        keep_mask[i + 1:][distances < distance_threshold] = False\n    # Keep only the rows that were not suppressed\n    reduced_df = sorted_df[keep_mask].reset_index(drop=True)\n    return reduced_df\n\ndef filter_false_positives(nodules_df, ct_data, max_nodules=10):\n    \"\"\"\n    Filter likely false positives: cap the number of nodules, then drop nodules whose\n    position fails the anatomical checks in nodule_valid\n    \n    Args:\n        nodules_df: DataFrame of nodule predictions\n        ct_data: CTData object\n        max_nodules: maximum number of nodules kept per patient\n\n    Returns:\n        Filtered DataFrame\n    \"\"\"\n    if nodules_df.empty:\n        return nodules_df\n    # 1. Cap the number of nodules\n    if len(nodules_df) > max_nodules:\n        # keep only the max_nodules highest-probability ones\n        nodules_df = nodules_df.sort_values('prob', ascending=False).head(max_nodules)\n    # 2. 
Position-based filtering\n    filtered_rows = []\n    for _, row in nodules_df.iterrows():\n        x, y, z = int(row['voxel_coord_x']), int(row['voxel_coord_y']), int(row['voxel_coord_z'])\n        this_nodule_valid = nodule_valid(ct_data, x, y, z)\n        if this_nodule_valid:\n            # Passed all checks: keep this nodule\n            filtered_rows.append(row)\n    # Build the filtered DataFrame\n    filtered_df = pd.DataFrame(filtered_rows)\n    # 3. (Disabled) second pass on probability:\n    # drop nodules below a higher threshold\n    # high_prob_threshold = 0.95  # high-probability threshold\n    # filtered_df = filtered_df[filtered_df['prob'] >= high_prob_threshold]\n    return filtered_df\n\ndef format_results(results_df, ct_data, patient_id):\n    \"\"\"\n    Format the detections as the final output DataFrame\n    \n    Args:\n        results_df: DataFrame of the merged, filtered nodules\n        ct_data: CTData object\n        patient_id: patient ID\n        \n    Returns:\n        Final DataFrame, one row per nodule\n    \"\"\"\n    # No nodules: return an empty DataFrame with the output columns\n    if results_df.empty:\n        return pd.DataFrame(columns=['patient_id', 'nodule_id', 'voxel_x', 'voxel_y', 'voxel_z', \n                                     'world_x', 'world_y', 'world_z', 'diameter_mm', 'prob'])\n    # Output rows\n    final_results = []\n    # Number the nodules consecutively from 1 (the input index can have gaps after filtering)\n    for nodule_id, (_, row) in enumerate(results_df.iterrows(), start=1):\n        # No diameter is estimated: use CUBE_SIZE / 2 as a fixed default\n        diameter_mm = CUBE_SIZE / 2\n        # Append the row\n        final_results.append({\n            'patient_id': patient_id,\n            'nodule_id': nodule_id,\n            'voxel_x': int(row['voxel_coord_x']),\n            'voxel_y': int(row['voxel_coord_y']),\n            'voxel_z': int(row['voxel_coord_z']),\n            'world_x': row['world_coord_x'],\n            'world_y': row['world_coord_y'],\n            'world_z': row['world_coord_z'],\n            'diameter_mm': diameter_mm,\n            'prob': row['prob']\n        })\n    \n    # Build the DataFrame\n    final_df = pd.DataFrame(final_results)\n    \n    return final_df\n\ndef detect_nodules(file_path, model_path, detect_patient_id=None, device='cuda'):\n    \"\"\"\n    Entry point: run nodule detection on one CT scan\n    \n    Args:\n        
file_path: path to a CT file or a DICOM folder\n        model_path: path to the model weights file\n        detect_patient_id: patient ID; if None, it is derived from the file or folder name\n        device: compute device ('cuda' or 'cpu')\n        \n    Returns:\n        DataFrame with one row per detected nodule\n    \"\"\"\n    # Logging\n    logger = setup_logger()\n    # Derive the patient ID from the path if none was given\n    if detect_patient_id is None:\n        if os.path.isfile(file_path):\n            detect_patient_id = os.path.splitext(os.path.basename(file_path))[0]\n        else:\n            detect_patient_id = os.path.basename(file_path)\n    logger.info(f\"Processing CT data for patient {detect_patient_id}\")\n    try:\n        # Load the CT data\n        logger.info(f\"Loading CT data: {file_path}\")\n        ct_data = load_ct_data(file_path)\n        # Load the model\n        logger.info(f\"Loading model: {model_path}\")\n        model, device = load_model(model_path, device)\n        # Sliding-window scan\n        results_df = scan_ct_data(ct_data, model, device, logger)\n        # Merge overlapping detections\n        logger.info(\"Merging overlapping nodules...\")\n        reduced_df = reduce_overlapping_nodules(results_df)\n        logger.info(f\"Nodules after merging: {len(reduced_df)}\")\n        # Filter false positives\n        logger.info(\"Filtering false-positive nodules...\")\n        filtered_df = filter_false_positives(reduced_df, ct_data)\n        logger.info(f\"Nodules after filtering: {len(filtered_df)}\")\n        # Format the results\n        final_df = format_results(filtered_df, ct_data, detect_patient_id)\n        logger.info(f\"Detection complete: {len(final_df)} nodules found\")\n        return final_df\n    except Exception as e:\n        logger.error(f\"Error during detection: {str(e)}\", exc_info=True)\n        raise\n    \nif __name__ == \"__main__\":\n    test_mhd = \"H:/luna16/subset8/1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139.mhd\"\n    model_path = \"../training/pytorch_checkpoints/best_model.pth\"\n    threshold = 0.7  # unused: detection uses PROB_THRESHOLD\n    patient_id = \"1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139\"\n    detect_result_csv = \"./c3d_classify_result-%s.csv\" %patient_id\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    # Run detection\n    result_df = 
detect_nodules(test_mhd, model_path, patient_id, device)\n    # Save the results\n    result_df.to_csv(detect_result_csv, index=False, encoding=\"utf-8\")"
  },
  {
    "path": "deploy/backend/preprocessing/__init__.py",
    "content": ""
  },
  {
    "path": "deploy/backend/preprocessing/luna16_invalid_nodule_filter.py",
    "content": "## Filter out problematic annotations in the LUNA16 candidate-nodule data, and spurious nodules found during prediction\nimport numpy as np\n\ndef nodule_valid(ct_data, voxel_coord_x, voxel_coord_y,voxel_coord_z):\n    \"\"\"\n        Decide whether a nodule can be used as a training cube, or whether a cube found by the scan is a plausible nodule\n    :param ct_data:         CTData object already converted to 0-255, with the lung region extracted\n    :param voxel_coord_x:   voxel coordinates of the cube center to check\n    :param voxel_coord_y:\n    :param voxel_coord_z:\n    :return:                True if the nodule is usable, False otherwise\n    \"\"\"\n    lung_mask = ct_data.lung_seg_mask\n    # Reject centers outside the image volume\n    if (voxel_coord_z < 0 or voxel_coord_z >= lung_mask.shape[0] or\n            voxel_coord_y < 0 or voxel_coord_y >= lung_mask.shape[1] or\n            voxel_coord_x < 0 or voxel_coord_x >= lung_mask.shape[2]):\n        return False\n    # Neighborhood of radius 5 voxels (up to 11x11x11), clipped to the volume\n    z_min = max(0, voxel_coord_z - 5)\n    z_max = min(lung_mask.shape[0], voxel_coord_z + 6)\n    y_min = max(0, voxel_coord_y - 5)\n    y_max = min(lung_mask.shape[1], voxel_coord_y + 6)\n    x_min = max(0, voxel_coord_x - 5)\n    x_max = min(lung_mask.shape[2], voxel_coord_x + 6)\n    # Lung mask of the neighborhood\n    neighborhood_mask = lung_mask[z_min:z_max, y_min:y_max, x_min:x_max]\n    # Fraction of lung voxels\n    lung_ratio = np.mean(neighborhood_mask)\n    # Too little lung tissue around the center: likely a false positive\n    if lung_ratio < 0.5:\n        return False\n\n    # Check whether the point lies on the lung boundary\n    # (only for points off the volume border, so that all 6 neighbors exist)\n    if (0 < voxel_coord_z < lung_mask.shape[0] - 1 and\n            0 < voxel_coord_y < lung_mask.shape[1] - 1 and\n            0 < voxel_coord_x < lung_mask.shape[2] - 1):\n        # Lung-mask values of the 6-connected neighbors (+/-z, +/-y, +/-x)\n        neighbors = [\n            lung_mask[voxel_coord_z - 1, voxel_coord_y, voxel_coord_x],\n            lung_mask[voxel_coord_z + 1, voxel_coord_y, voxel_coord_x],\n            lung_mask[voxel_coord_z, voxel_coord_y - 1, voxel_coord_x],\n            lung_mask[voxel_coord_z, voxel_coord_y + 1, voxel_coord_x],\n            lung_mask[voxel_coord_z, voxel_coord_y, voxel_coord_x - 1],\n            lung_mask[voxel_coord_z, voxel_coord_y, voxel_coord_x + 1]\n 
       ]\n        # Fewer than 4 of the 6 neighbors are lung voxels: the point probably sits on the lung boundary\n        if sum(neighbors) < 4:\n            return False\n    return True"
  },
  {
    "path": "deploy/backend/util/__init__.py",
    "content": ""
  },
  {
    "path": "deploy/backend/util/dicom_util.py",
    "content": "import os\nimport glob\nimport pydicom\nimport numpy as np\nimport cv2\nfrom tqdm import tqdm\n\nfrom util.seg_util import get_segmented_lungs,normalize_hu_values\nfrom util.image_util import rescale_patient_images\n\n\ndef is_dicom_file(filename):\n    '''\n        Check whether a file is a DICOM file ('DICM' magic bytes at offset 128)\n    :param filename:      path of the file to check\n    :return:              True if the file is a DICOM file\n    '''\n    with open(filename, 'rb') as file_stream:\n        file_stream.seek(128)\n        data = file_stream.read(4)\n    return data == b'DICM'\n\ndef get_dicom_thickness(dicom_slices):\n    \"\"\"\n        Estimate the slice thickness from the z distance between the first two slices,\n        falling back to SliceLocation and then to the SliceThickness tag\n    :param dicom_slices:    DICOM datasets read with pydicom, in slice order\n    :return:                slice thickness in mm\n    \"\"\"\n    if len(dicom_slices) > 1:\n        try:\n            slice_thickness = abs(dicom_slices[0].ImagePositionPatient[2] - dicom_slices[1].ImagePositionPatient[2])\n        except Exception:\n            try:\n                slice_thickness = abs(dicom_slices[0].SliceLocation - dicom_slices[1].SliceLocation)\n            except Exception:\n                # Could not compute it: fall back to the SliceThickness tag\n                try:\n                    slice_thickness = float(dicom_slices[0].SliceThickness)\n                except Exception:\n                    print(\"Warning: could not determine the slice thickness; defaulting to 1.0 mm\")\n                    slice_thickness = 1.0\n    else:\n        try:\n            slice_thickness = float(dicom_slices[0].SliceThickness)\n        except Exception:\n            print(\"Warning: only one slice, so the slice thickness cannot be computed; defaulting to 1.0 mm\")\n            slice_thickness = 1.0\n    return slice_thickness\n\ndef load_dicom_slices(dicom_path):\n    \"\"\"\n        Find the DICOM files under dicom_path and load them as a list of slices\n\n    :param dicom_path:  root folder of one DICOM series\n    :return:            list of pydicom datasets\n    \"\"\"\n    dicom_files = []\n    for root, _, files in os.walk(dicom_path):\n        for file in files:\n            if file.lower().endswith(('.dcm', '.dicom')):\n                # root already includes dicom_path\n                real_file = os.path.join(root, file)\n                current_if_dicom = 
is_dicom_file(real_file)\n                if current_if_dicom:\n                    dicom_files.append(real_file)\n    if not dicom_files:\n        raise ValueError(f\"No DICOM files found under {dicom_path}\")\n    # Read all slices\n    slices = []\n    for file in dicom_files:\n        try:\n            ds = pydicom.dcmread(file)\n            slices.append(ds)\n        except Exception as e:\n            print(f\"Could not read DICOM file {file}: {e}\")\n    # Sort the slices by InstanceNumber, used here as a proxy for z position\n    slices.sort(key=lambda x: int(x.InstanceNumber))\n    slice_thickness = get_dicom_thickness(slices)\n    for s in slices:\n        s.SliceThickness = slice_thickness\n    return slices\n\ndef get_pixels_hu(slices):\n    '''\n        Stack the DICOM slices into one array and convert it to Hounsfield units;\n        padding pixels outside the scan (stored value -2000) are set to 0 first\n\n    :param slices:  list of DICOM slices\n    :return:        int16 array of HU values for one patient\n    '''\n    image = np.stack([s.pixel_array for s in slices])\n    image = image.astype(np.int16)\n    image[image == -2000] = 0\n    for slice_number in range(len(slices)):\n        # HU = stored value * RescaleSlope + RescaleIntercept\n        intercept = slices[slice_number].RescaleIntercept\n        slope = slices[slice_number].RescaleSlope\n        if slope != 1:\n            image[slice_number] = slope * image[slice_number].astype(np.float64)\n            image[slice_number] = image[slice_number].astype(np.int16)\n        image[slice_number] += np.int16(intercept)\n    return np.array(image, dtype=np.int16)\n\n\ndef getinfo_dicom(dicom_path):\n    \"\"\"Print basic information about a DICOM series; returns (pixel_spacing, dicom_size, invert_order)\"\"\"\n    print('dicom_path: ', dicom_path)\n    slices = load_dicom_slices(dicom_path)\n    print(type(slices[0]), slices[0].ImagePositionPatient)\n    print(len(slices), \"\\t\", slices[0].SliceThickness, \"\\t\", slices[0].PixelSpacing)\n    print(\"Orientation: \", slices[0].ImageOrientationPatient)\n    #assert slices[0].ImageOrientationPatient == [1.000000, 0.000000, 0.000000, 0.000000, 1.000000, 0.000000]\n    pixels = get_pixels_hu(slices)\n    image = pixels\n    print(image.shape)\n\n    invert_order = slices[1].ImagePositionPatient[2] > 
slices[0].ImagePositionPatient[2]\n    print(\"Invert order: \", invert_order, \" - \", slices[1].ImagePositionPatient[2], \",\",\n          slices[0].ImagePositionPatient[2])\n\n    # Spacing as a new list, so the dataset's PixelSpacing element is not modified\n    pixel_spacing = list(slices[0].PixelSpacing) + [slices[0].SliceThickness]\n    # save dicom source image size\n    dicom_size = [image.shape[0], image.shape[1], image.shape[2]]\n\n    return pixel_spacing, dicom_size, invert_order\n\ndef extract_dicom_images_patient(dicom_path, target_dir):\n    \"\"\"Convert one DICOM series into per-slice PNG images (_i.png) and lung masks (_m.png) in target_dir\"\"\"\n    slices = load_dicom_slices(dicom_path)\n    assert slices[0].ImageOrientationPatient == [1.000000, 0.000000, 0.000000, 0.000000, 1.000000, 0.000000]\n    pixels = get_pixels_hu(slices)\n    image = pixels\n    invert_order = slices[1].ImagePositionPatient[2] > slices[0].ImagePositionPatient[2]\n    pixel_spacing = list(slices[0].PixelSpacing) + [slices[0].SliceThickness]\n    # save dicom source image size\n    dicom_size = [image.shape[0], image.shape[1], image.shape[2]]\n    image = rescale_patient_images(image, pixel_spacing)\n    png_size = [image.shape[0], image.shape[1], image.shape[2]]\n    if not invert_order:\n        image = np.flipud(image)\n    if not os.path.exists(target_dir):\n        os.makedirs(target_dir)\n    else:\n        print(\"png dir already exists, return directly\")\n        return pixel_spacing, dicom_size, png_size, invert_order\n    png_files = glob.glob(os.path.join(target_dir, \"*.png\"))\n    for file in png_files:\n        os.remove(file)\n    for i in tqdm(range(image.shape[0])):\n        img_path = os.path.join(target_dir, \"img_\" + str(i).rjust(4, '0') + \"_i.png\")\n        org_img = image[i]\n        img, mask = get_segmented_lungs(org_img.copy())\n        org_img = normalize_hu_values(org_img)\n        cv2.imwrite(img_path, org_img * 255)\n        cv2.imwrite(img_path.replace(\"_i.png\", \"_m.png\"), mask * 255)\n    return pixel_spacing, dicom_size, png_size,invert_order\n\n"
  },
  {
    "path": "deploy/backend/util/image_util.py",
    "content": "from typing import Tuple\n\nimport cv2\nimport os\nimport numpy\nimport glob\nimport random\nimport numpy as np\nfrom scipy import ndimage\n\ndef get_normalized_img_unit8(img):\n    \"\"\"Min-max normalize an image to the range 0-255 and return it as uint8\"\"\"\n    img = img.astype(numpy.float64)\n    img_min = img.min()\n    img_max = img.max()\n    img -= img_min\n    # Avoid dividing by zero on a constant image\n    if img_max > img_min:\n        img /= img_max - img_min\n    img *= 255\n    res = img.astype(numpy.uint8)\n    return res\n\n\ndef load_patient_images(png_path, wildcard=\"*.*\", exclude_wildcards=[]):\n    \"\"\"Load the sorted slice images in png_path as a (z, y, x) grayscale volume\"\"\"\n    print(\"png path is\\t\",png_path)\n    src_dir = png_path\n    src_img_paths = glob.glob(src_dir +'/'+ wildcard)\n    for exclude_wildcard in exclude_wildcards:\n        exclude_img_paths = glob.glob(src_dir + '/' + exclude_wildcard)\n        src_img_paths = [im for im in src_img_paths if im not in exclude_img_paths]\n    src_img_paths.sort()\n\n    images = [cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) for img_path in src_img_paths]\n    images = [im.reshape((1, ) + im.shape) for im in images]\n    res = numpy.vstack(images)\n    return res\n\n\ndef draw_overlay(png_path: str, p_x: float, p_y: float, p_z: float, index: str,  BOX_size:int = 20) -> None:\n    \"\"\"\n    Draw a box around a point on one slice and save the slice to png_path, named after index\n    Args:\n        png_path: folder containing the png/ subfolder of slice images\n        p_x: relative X coordinate (fraction of the image width, 0-1)\n        p_y: relative Y coordinate (fraction of the image height, 0-1)\n        p_z: relative Z coordinate (fraction of the slice count, 0-1)\n        index: name of the output image\n        BOX_size: half-width of the box, in pixels\n    \"\"\"\n    patient_img = load_patient_images(png_path + \"/png/\", \"*_i.png\", [])\n    z = int(p_z * patient_img.shape[0])\n    y = int(p_y * patient_img.shape[1])\n    x = int(p_x * patient_img.shape[2])\n    # Box corners\n    x1 = x - BOX_size\n    y1 = y - BOX_size\n    x2 = x + BOX_size\n    y2 = y + BOX_size\n    target_img = patient_img[z, :, :]\n    cv2.rectangle(target_img, (x1, y1), (x2, y2), (255, 0, 0), 1)\n    cv2.imwrite(png_path + \"/\" + index + \".png\", target_img)\n\ndef prepare_image_for_net3D(img,MEAN_PIXEL_VALUE = 41):\n    '''\n        Normalize an image: subtract the mean pixel value and scale by 1/255\n\n    :param img:               image to normalize\n    :param 
MEAN_PIXEL_VALUE:  mean pixel value subtracted before scaling\n    :return:                  float32 array of shape (1, z, y, x, 1)\n    '''\n    img = img.astype(numpy.float32)\n    img -= MEAN_PIXEL_VALUE\n    img /= 255.\n    img = img.reshape(1, img.shape[0], img.shape[1], img.shape[2], 1)\n    return img\n\n\ndef move_png2dir(target_dir):\n    \"\"\"Move the .png files of every subfolder of target_dir into a png/ folder inside that subfolder\"\"\"\n    import shutil\n    first_dir = []\n    for path in os.listdir(target_dir):\n        if os.path.isdir(os.path.join(target_dir,path)):\n            first_dir.append(os.path.join(target_dir,path))\n    for d in first_dir:\n        for file in os.listdir(d):\n            tmp_file_path = os.path.join(d,file)\n            png_path = os.path.join(d,'png')\n            if not os.path.exists(png_path):\n                os.mkdir(png_path)\n            if tmp_file_path.endswith(\".png\"):\n                shutil.move(tmp_file_path,os.path.join(png_path,file))\n                print(\"move file from %s to %s\" %(tmp_file_path,os.path.join(png_path,file)))\n\ndef rescale_patient_images(images_zyx, org_spacing_xyz, target_voxel_mm =1.0, is_mask_image=False, verbose=False):\n    '''\n        Resample a (z, y, x) volume to an isotropic voxel size of target_voxel_mm\n\n    :param images_zyx:       source volume, indexed (z, y, x)\n    :param org_spacing_xyz:  original voxel spacing (x, y, z) in mm\n    :param target_voxel_mm:  target voxel size in mm\n    :param is_mask_image:    use nearest-neighbor interpolation (for masks) instead of linear\n    :param verbose:          print the shapes before and after resampling\n    :return:                 resampled volume, indexed (z, y, x)\n    '''\n    if verbose:\n        print(\"Spacing: \", org_spacing_xyz)\n        print(\"Shape: \", images_zyx.shape)\n\n    # Step 1: resize along z\n    resize_x = 1.0\n    resize_y = float(org_spacing_xyz[2]) / float(target_voxel_mm)\n    interpolation = cv2.INTER_NEAREST if is_mask_image else cv2.INTER_LINEAR\n    # OpenCV treats the array as (rows, cols, channels), so rows = z here\n    res = cv2.resize(images_zyx, dsize=None, fx=resize_x, fy=resize_y, interpolation=interpolation)\n    res = res.swapaxes(0, 2)\n    res = res.swapaxes(0, 1)\n    # Step 2: resize in the axial plane\n    resize_x = float(org_spacing_xyz[0]) / float(target_voxel_mm)\n    resize_y = float(org_spacing_xyz[1]) / float(target_voxel_mm)\n\n    # 
cv2.resize handles at most 512 channels, so resize the first 256 channels and the rest separately\n    if res.shape[2] > 512:\n        res = res.swapaxes(0, 2)\n        res1 = res[:256]\n        res2 = res[256:]\n        res1 = res1.swapaxes(0, 2)\n        res2 = res2.swapaxes(0, 2)\n        res1 = cv2.resize(res1, dsize=None, fx=resize_x, fy=resize_y, interpolation=interpolation)\n        res2 = cv2.resize(res2, dsize=None, fx=resize_x, fy=resize_y, interpolation=interpolation)\n        res1 = res1.swapaxes(0, 2)\n        res2 = res2.swapaxes(0, 2)\n        res = numpy.vstack([res1, res2])\n        res = res.swapaxes(0, 2)\n    else:\n        res = cv2.resize(res, dsize=None, fx=resize_x, fy=resize_y, interpolation=interpolation)\n    res = res.swapaxes(0, 2)\n    res = res.swapaxes(2, 1)\n    if verbose:\n        print(\"Shape after: \", res.shape)\n    return res\n\n\ndef rescale_patient_images2(images_zyx, target_shape, verbose=False):\n    \"\"\"Resize a (z, y, x) volume to target_shape, given as (z, y, x), with linear interpolation\"\"\"\n    if verbose:\n        print(\"Target: \", target_shape)\n        print(\"Shape: \", images_zyx.shape)\n\n    interpolation = cv2.INTER_LINEAR\n    # Step 1: resize z and y (OpenCV treats the array as (rows, cols, channels), so rows = z here)\n    res = cv2.resize(images_zyx, dsize=(target_shape[1], target_shape[0]), interpolation=interpolation)\n\n    res = res.swapaxes(0, 2)\n    res = res.swapaxes(0, 1)\n\n    # Step 2: resize x; cv2.resize handles at most 512 channels, so resize the first 256 channels and the rest separately\n    if res.shape[2] > 512:\n        res = res.swapaxes(0, 2)\n        res1 = res[:256]\n        res2 = res[256:]\n        res1 = res1.swapaxes(0, 2)\n        res2 = res2.swapaxes(0, 2)\n        res1 = cv2.resize(res1, dsize=(target_shape[2], target_shape[1]), interpolation=interpolation)\n        res2 = cv2.resize(res2, dsize=(target_shape[2], target_shape[1]), interpolation=interpolation)\n        res1 = res1.swapaxes(0, 2)\n        res2 = res2.swapaxes(0, 2)\n        res = numpy.vstack([res1, res2])\n        res = res.swapaxes(0, 2)\n    else:\n        res 
= cv2.resize(res, dsize=(target_shape[2], target_shape[1]), interpolation=interpolation)\n\n    res = res.swapaxes(0, 2)\n    res = res.swapaxes(2, 1)\n    if verbose:\n        print(\"Shape after: \", res.shape)\n    return res\n\n\ndef resize_image(image: np.ndarray, new_shape: Tuple[int, ...]) -> np.ndarray:\n    \"\"\"\n    Resize an image\n\n    Args:\n        image: input image\n        new_shape: target shape; for a 3D volume with a 2D new_shape, every slice along axis 0 is resized\n\n    Returns:\n        np.ndarray: resized image\n    \"\"\"\n    # Handle single-channel or multi-channel images\n    if len(image.shape) == 3 and len(new_shape) == 2:\n        # 3D volume with a 2D target shape: resize each slice\n        resized_image = np.zeros((image.shape[0], new_shape[0], new_shape[1]))\n        for i in range(image.shape[0]):\n            resized_image[i] = cv2.resize(image[i], (new_shape[1], new_shape[0]))\n        return resized_image\n    elif len(image.shape) == 2 and len(new_shape) == 2:\n        # 2D image (cv2 dsize is (width, height))\n        return cv2.resize(image, (new_shape[1], new_shape[0]))\n    else:\n        # Any other dimensionality: zoom with scipy\n        resize_factor = tuple(n / o for n, o in zip(new_shape, image.shape))\n        return ndimage.zoom(image, resize_factor, mode='nearest')\n\ndef cv_flip(img,cols,rows,degree):\n    '''\n        rotate an image by `degree` degrees around its center (despite the name, this is a rotation, not a flip)\n\n    :param img:         image array to be rotated\n    :param cols:        width of image\n    :param rows:        height of image\n    :param degree:      rotation angle in degrees (positive values rotate counter-clockwise)\n    :return:            rotated image\n    '''\n    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), degree, 1.0)\n    dst = cv2.warpAffine(img, M, (cols, rows))\n    return dst\n\n\ndef random_rotate_img(img, chance, min_angle, max_angle):\n    '''\n        randomly rotate an image (or a list of images, all by the same angle)\n\n    :param img:         image or list of images to be rotated\n    :param chance:      probability of applying the rotation\n    :param min_angle:   minimum rotation angle in degrees\n    :param max_angle:   maximum rotation angle in degrees\n    :return:            rotated image, or a list of images if several were passed\n    '''\n    import cv2\n    if random.random() > chance:\n        return img\n    if not isinstance(img, list):\n        img = [img]\n\n    angle = 
random.randint(min_angle, max_angle)\n    # OpenCV expects the center as (x, y), i.e. (cols / 2, rows / 2)\n    center = (img[0].shape[1] / 2, img[0].shape[0] / 2)\n    rot_matrix = cv2.getRotationMatrix2D(center, angle, scale=1.0)\n\n    res = []\n    for img_inst in img:\n        # dsize is (width, height)\n        img_inst = cv2.warpAffine(img_inst, rot_matrix, dsize=(img_inst.shape[1], img_inst.shape[0]), borderMode=cv2.BORDER_CONSTANT)\n        res.append(img_inst)\n    if len(res) == 1:\n        res = res[0]\n    return res\n\n\ndef random_flip_img(img, horizontal_chance=0, vertical_chance=0):\n    '''\n        randomly flip an image horizontally and/or vertically\n\n    :param img:                 image or list of images to be flipped\n    :param horizontal_chance:   probability of a horizontal flip\n    :param vertical_chance:     probability of a vertical flip\n    :return:                    flipped image (or list of images)\n    '''\n    import cv2\n    flip_horizontal = False\n    if random.random() < horizontal_chance:\n        flip_horizontal = True\n\n    flip_vertical = False\n    if random.random() < vertical_chance:\n        flip_vertical = True\n\n    if not flip_horizontal and not flip_vertical:\n        return img\n\n    flip_val = 1\n    if flip_vertical:\n        flip_val = -1 if flip_horizontal else 0\n\n    if not isinstance(img, list):\n        res = cv2.flip(img, flip_val)  # 0 = X axis, 1 = Y axis,  -1 = both\n    else:\n        res = []\n        for img_item in img:\n            img_flip = cv2.flip(img_item, flip_val)\n            res.append(img_flip)\n    return res\n\n\ndef random_scale_img(img, xy_range, lock_xy=False):\n    if random.random() > xy_range.chance:\n        return img\n\n    if not isinstance(img, list):\n        img = [img]\n\n    import cv2\n    scale_x = random.uniform(xy_range.x_min, xy_range.x_max)\n    scale_y = random.uniform(xy_range.y_min, xy_range.y_max)\n    if lock_xy:\n        scale_y = scale_x\n\n    org_height, org_width = img[0].shape[:2]\n    xy_range.last_x = scale_x\n    xy_range.last_y = scale_y\n\n    res = []\n    for 
img_inst in img:\n        scaled_width = int(org_width * scale_x)\n        scaled_height = int(org_height * scale_y)\n        scaled_img = cv2.resize(img_inst, (scaled_width, scaled_height), interpolation=cv2.INTER_CUBIC)\n        # integer division: copyMakeBorder and slicing need int offsets\n        if scaled_width < org_width:\n            extend_left = (org_width - scaled_width) // 2\n            extend_right = org_width - extend_left - scaled_width\n            scaled_img = cv2.copyMakeBorder(scaled_img, 0, 0, extend_left, extend_right, borderType=cv2.BORDER_CONSTANT)\n            scaled_width = org_width\n\n        if scaled_height < org_height:\n            extend_top = (org_height - scaled_height) // 2\n            extend_bottom = org_height - extend_top - scaled_height\n            scaled_img = cv2.copyMakeBorder(scaled_img, extend_top, extend_bottom, 0, 0, borderType=cv2.BORDER_CONSTANT)\n            scaled_height = org_height\n\n        start_x = (scaled_width - org_width) // 2\n        start_y = (scaled_height - org_height) // 2\n        tmp = scaled_img[start_y: start_y + org_height, start_x: start_x + org_width]\n        res.append(tmp)\n\n    return res\n\n\nclass XYRange:\n    def __init__(self, x_min, x_max, y_min, y_max, chance=1.0):\n        self.chance = chance\n        self.x_min = x_min\n        self.x_max = x_max\n        self.y_min = y_min\n        self.y_max = y_max\n        self.last_x = 0\n        self.last_y = 0\n\n    def get_last_xy_txt(self):\n        res = \"x_\" + str(int(self.last_x * 100)).replace(\"-\", \"m\") + \"-\" + \"y_\" + str(int(self.last_y * 100)).replace(\n            \"-\", \"m\")\n        return res\n\n\ndef random_translate_img(img, xy_range, border_mode=\"constant\"):\n    if random.random() > xy_range.chance:\n        return img\n    import cv2\n    if not isinstance(img, list):\n        img = [img]\n\n    org_height, org_width = img[0].shape[:2]\n    translate_x = random.randint(xy_range.x_min, xy_range.x_max)\n    translate_y = random.randint(xy_range.y_min, xy_range.y_max)\n    
trans_matrix = numpy.float32([[1, 0, translate_x], [0, 1, translate_y]])\n\n    border_const = cv2.BORDER_CONSTANT\n    if border_mode == \"reflect\":\n        border_const = cv2.BORDER_REFLECT\n\n    res = []\n    for img_inst in img:\n        img_inst = cv2.warpAffine(img_inst, trans_matrix, (org_width, org_height), borderMode=border_const)\n        res.append(img_inst)\n    if len(res) == 1:\n        res = res[0]\n    xy_range.last_x = translate_x\n    xy_range.last_y = translate_y\n    return res\n\n\ndef data_augmentation(image: np.ndarray, augment_type: str = 'random') -> np.ndarray:\n    \"\"\"\n    Apply data augmentation to an image\n\n    Args:\n        image: input image\n        augment_type: augmentation type, one of 'random', 'flip', 'rotate', 'shift'\n\n    Returns:\n        np.ndarray: augmented image\n    \"\"\"\n    if augment_type == 'random':\n        # Pick one augmentation (or none) at random\n        augment_choices = ['flip', 'rotate', 'shift', 'none']\n        choice = np.random.choice(augment_choices)\n\n        if choice == 'flip':\n            return data_augmentation(image, 'flip')\n        elif choice == 'rotate':\n            return data_augmentation(image, 'rotate')\n        elif choice == 'shift':\n            return data_augmentation(image, 'shift')\n        else:\n            return image\n\n    elif augment_type == 'flip':\n        # Random flip along one axis\n        axis = np.random.randint(0, image.ndim)\n        return np.flip(image, axis=axis)\n\n    elif augment_type == 'rotate':\n        # Random rotation\n        if image.ndim == 2:\n            angle = np.random.randint(0, 360)\n            return ndimage.rotate(image, angle, reshape=False, mode='nearest')\n        else:\n            # 3D: rotate within a randomly chosen plane\n            axes = tuple(np.random.choice(range(image.ndim), size=2, replace=False))\n            angle = np.random.randint(0, 360)\n            return ndimage.rotate(image, angle, axes=axes, reshape=False, mode='nearest')\n\n    elif augment_type == 'shift':\n        # Random shift of -5 to 5 voxels along each axis\n        shift = np.random.randint(-5, 6, size=image.ndim)\n        return 
ndimage.shift(image, shift, mode='nearest')\n\n    return image"
  },
  {
    "path": "deploy/backend/util/mhd_util.py",
    "content": "import os\nimport ntpath\nimport SimpleITK\nimport numpy as np\nimport pandas as pd\nimport cv2\n\nfrom data.dataclass.CTData import CTData\nfrom util.seg_util import normalize_hu_values,get_segmented_lungs\nfrom util import image_util\nfrom constant import tianchi\n\nTARGET_VOXEL_MM = 1.0\nMHD_INFO_HEAD = \"patient_id,shape_0,shape_1,shape_2,origin_x,origin_y,origin_z,direction_z(1_-1),\" \\\n                \"spacing_x,spacing_y,spacing_z,rescale_x,rescale_y,rescale_z\"\n\ndef get_all_mhd_file(BASE_DATA_DIR,base_head,max):\n    \"\"\"\n       get all mhd file list ,tianchi mhd file consist of train_subset00,train_subset01,... test_subset00,test_subset01,..\n\n    :param base_head:       'train' or 'test',or 'val', to construct train_subset00,test_subset01,val_subset02...\n    :param max:             the max suffix of path ,such as train_subset09, then max=09\n    :return:                all mhd file list\n    \"\"\"\n    mhd_files = []\n    for index in range(0,max):\n        if index<10:\n            index = \"0\"+str(index)\n        else:\n            index =str(index)\n        sub_path = os.path.join(BASE_DATA_DIR,base_head+\"_subset\"+index)\n        for name in os.listdir(sub_path):\n            if name.endswith(\".mhd\"):\n                mhd_files.append(os.path.join(sub_path,name))\n    return mhd_files\n\n\ndef get_luna16_mhd_file(mhd_root):\n    \"\"\"\n       get all mhd file list ,tianchi mhd file consist of train_subset00,train_subset01,... 
test_subset00,test_subset01,..\n\n    :param mhd_root:       'train' or 'test',or 'val', to construct train_subset00,test_subset01,val_subset02...\n    :return:                all mhd file list\n    \"\"\"\n    mhd_files = []\n    for root, _, files in os.walk(mhd_root):\n        for file in files:\n            if file.lower().endswith('.mhd'):\n                real_file = os.path.join(mhd_root, root, file)\n                mhd_files.append(real_file)\n    return mhd_files\ndef read_csv_to_pandas(mhd_info,col_sepator ='\\t'):\n    \"\"\"\n       read csv information into pandas dataframe\n\n    :param mhd_info:      csv file of mhd file\n    :param col_sepator:  sepator string of columns\n    :return:\n    \"\"\"\n    with open(mhd_info, 'r') as csv:\n        head = csv.readline().split(\",\")  # get header of csv\n        indexs = []\n        lines = csv.readlines()\n        list = []\n        for line in lines:\n            list.append(line.split(col_sepator))\n            indexs.append(line.split(col_sepator)[0])  #the first element should be id of patient\n        df = pd.DataFrame(data=list, columns=head,index=indexs)\n        return df\n\ndef extract_image_from_mhd(mhd_file_path,png_save_path_root =None):\n    \"\"\"\n        extract image from mhd file and return mhd information\n\n    :param mhd_file_path:       mhd file to extract\n    :param png_save_path_root:  file path where to save the extracted image (both image and mask image will be saved)\n                                ,if this param is None means only mhd information returns,no image extracted\n    :return:\n    \"\"\"\n    mhd_info = []\n    patient_id = ntpath.basename(mhd_file_path).replace(\".mhd\", \"\")\n    print(\"Patient: \", patient_id)\n    mhd_info.append(patient_id)\n    if not os.path.exists(png_save_path_root):\n        os.mkdir(png_save_path_root)\n    dst_dir = png_save_path_root+'/' + patient_id + \"/\"\n    if not os.path.exists(dst_dir):\n        os.mkdir(dst_dir)\n\n    
itk_img = SimpleITK.ReadImage(mhd_file_path)\n    img_array = SimpleITK.GetArrayFromImage(itk_img)\n    print(\"Img array: \", img_array.shape)\n    (shape0,shape1,shape2) = img_array.shape\n    mhd_info.append(str(shape2))\n    mhd_info.append(str(shape1))\n    mhd_info.append(str(shape0))\n\n    origin = np.array(itk_img.GetOrigin())      # x,y,z  Origin in world coordinates (mm)\n    print(\"Origin (x,y,z): \", origin)\n    mhd_info.append(str(origin[0]))\n    mhd_info.append(str(origin[1]))\n    mhd_info.append(str(origin[2]))\n\n    direction = np.array(itk_img.GetDirection())      # x,y,z  Origin in world coordinates (mm)\n    print(\"Direction: \", direction)\n    direct_arow = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]\n    if direction.tolist() == direct_arow:\n        print(\"positive direction..\")\n        mhd_info.append(str(1))\n    else:\n        mhd_info.append(str(-1))\n\n    spacing = np.array(itk_img.GetSpacing())    # spacing of voxels in world coor. (mm)\n    print(\"Spacing (x,y,z): \", spacing)\n    mhd_info.append(str(spacing[0]))\n    mhd_info.append(str(spacing[1]))\n    mhd_info.append(str(spacing[2]))\n\n    rescale = spacing /TARGET_VOXEL_MM\n    print(\"Rescale: \", rescale)\n    mhd_info.append(str(rescale[0]))\n    mhd_info.append(str(rescale[1]))\n    mhd_info.append(str(rescale[2]))\n\n    if png_save_path_root is None:              # get mhd information only\n        return mhd_info\n\n    if not os.path.exists(dst_dir):\n        if img_array.shape[1]== 512:\n            img_array = image_util.rescale_patient_images(img_array, spacing, TARGET_VOXEL_MM)\n        img_list = []\n        for i in range(img_array.shape[0]):\n            img = img_array[i]\n            seg_img, mask = get_segmented_lungs(img.copy())\n            img_list.append(seg_img)\n            img = normalize_hu_values(img)\n            cv2.imwrite(dst_dir + \"img_\" + str(i).rjust(4, '0') + \"_i.png\", img * 255)\n            cv2.imwrite(dst_dir + \"img_\" + 
str(i).rjust(4, '0') + \"_m.png\", mask * 255)\n    return mhd_info\n\n"
  },
  {
    "path": "deploy/backend/util/seg_util.py",
    "content": "import numpy as np\nimport matplotlib.pyplot as plt\nfrom scipy import ndimage as ndi\nfrom skimage.filters import roberts\nfrom skimage.measure import regionprops, label\nfrom skimage.morphology import binary_closing, disk, binary_erosion\nfrom skimage.segmentation import clear_border\n\n\ndef normalize_hu_values(image: np.ndarray, min_bound: int = -1000, max_bound: int = 400) -> np.ndarray:\n    \"\"\"\n    归一化HU值到[0,1]范围\n\n    Args:\n        image: 输入图像\n        min_bound: 最小HU值\n        max_bound: 最大HU值\n\n    Returns:\n        np.ndarray: 归一化后的图像\n    \"\"\"\n    image = (image - min_bound) / (max_bound - min_bound)\n    image[image > 1] = 1.\n    image[image < 0] = 0.\n    return image\n\ndef get_segmented_lungs(im, plot=False):\n    '''\n        extract lung ROI from pixel array\n\n    :param im:      a patient's piexl array\n    :param plot:    if plot when segment\n    :return:\n    '''\n    # Step 1: Convert into a binary image.\n    binary = im < -400\n    # Step 2: Remove the blobs connected to the border of the image.\n    cleared = clear_border(binary)\n    # Step 3: Label the image.\n    label_image = label(cleared)\n    # Step 4: Keep the labels with 2 largest areas.\n    areas = [r.area for r in regionprops(label_image)]\n    areas.sort()\n    if len(areas) > 2:\n        for region in regionprops(label_image):\n            if region.area < areas[-2]:\n                for coordinates in region.coords:\n                       label_image[coordinates[0], coordinates[1]] = 0\n    binary = label_image > 0\n    # Step 5: Erosion operation with a disk of radius 2. This operation is seperate the lung nodules attached to the blood vessels.\n    selem = disk(2)\n    binary = binary_erosion(binary, selem)\n    # Step 6: Closure operation with a disk of radius 10. 
This operation is    to keep nodules attached to the lung wall.\n    selem = disk(10) # CHANGE BACK TO 10\n    binary = binary_closing(binary, selem)\n    # Step 7: Fill in the small holes inside the binary mask of lungs.\n    edges = roberts(binary)\n    binary = ndi.binary_fill_holes(edges)\n    # Step 8: Superimpose the binary mask on the input image.\n    get_high_vals = binary == 0\n    im[get_high_vals] = -2000\n    if plot:\n        plt.figure(figsize=(10, 10))\n        plt.subplot(1, 2, 1)\n        plt.imshow(binary, cmap='gray')\n        plt.title('Lung Mask')\n        plt.subplot(1, 2, 2)\n        plt.imshow(im, cmap='gray')\n        plt.title('Masked Image')\n        plt.show()\n    return im, binary"
  },
  {
    "path": "deploy/backend/utils.py",
    "content": "import os\nimport sys\nimport logging\nimport numpy as np\nimport SimpleITK as sitk\nimport tempfile\nimport zipfile\nimport shutil\nimport pydicom\nfrom scipy import ndimage\n\n# 配置日志\nlogging.basicConfig(level=logging.INFO)\nlogger = logging.getLogger(__name__)\n\n# 尝试导入CTData类，如果不可用则创建一个简化版\ntry:\n    # 添加项目根目录到系统路径\n    sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))\n    \n    # 导入原始CTData类\n    from data.dataclass.CTData import CTData\nexcept ImportError:\n    # 如果无法导入，定义一个简化版的CTData类\n    class CTData:\n        \"\"\"CT数据类，用于加载和处理CT图像\"\"\"\n        \n        def __init__(self):\n            self.img = None  # 原始图像数据\n            self.origin = None  # 图像原点\n            self.spacing = None  # 图像间距\n            self.is_normalized = False  # 是否已标准化\n            self.has_lung_seg = False  # 是否已进行肺部分割\n            self.lung_seg_img = None  # 肺部分割图像\n            self.lung_mask = None  # 肺部掩码\n            \n        @classmethod\n        def from_dicom(cls, dicom_path):\n            \"\"\"从DICOM文件夹加载CT数据\"\"\"\n            ct_data = cls()\n            \n            try:\n                logger.info(f\"从DICOM加载: {dicom_path}\")\n                \n                # 处理DICOM目录或zip文件\n                temp_dir = None\n                \n                if os.path.isfile(dicom_path) and dicom_path.endswith('.zip'):\n                    # 创建临时目录解压缩\n                    temp_dir = tempfile.mkdtemp()\n                    with zipfile.ZipFile(dicom_path, 'r') as zip_ref:\n                        zip_ref.extractall(temp_dir)\n                    dicom_path = temp_dir\n                \n                # 读取DICOM序列\n                reader = sitk.ImageSeriesReader()\n                dicom_names = reader.GetGDCMSeriesFileNames(dicom_path)\n                \n                if not dicom_names:\n                    raise ValueError(f\"在{dicom_path}中未找到DICOM文件\")\n                \n                
reader.SetFileNames(dicom_names)\n                ct_sitk_img = reader.Execute()\n                \n                # Clean up the temporary directory\n                if temp_dir and os.path.exists(temp_dir):\n                    shutil.rmtree(temp_dir)\n                \n                # Convert to a numpy array (z, y, x); origin and spacing stay in SimpleITK (x, y, z) order\n                ct_data.img = sitk.GetArrayFromImage(ct_sitk_img)\n                ct_data.origin = ct_sitk_img.GetOrigin()\n                ct_data.spacing = ct_sitk_img.GetSpacing()\n                \n                # Convert to HU\n                ct_data.convert_to_hu()\n                \n                return ct_data\n            \n            except Exception as e:\n                logger.error(f\"Failed to load DICOM: {e}\")\n                # Make sure the temporary directory is cleaned up\n                if temp_dir and os.path.exists(temp_dir):\n                    shutil.rmtree(temp_dir)\n                raise\n        \n        @classmethod\n        def from_mhd(cls, mhd_path):\n            \"\"\"Load CT data from an MHD file\"\"\"\n            ct_data = cls()\n            \n            try:\n                logger.info(f\"Loading from MHD: {mhd_path}\")\n                \n                # Read the MHD file\n                ct_sitk_img = sitk.ReadImage(mhd_path)\n                \n                # Convert to a numpy array\n                ct_data.img = sitk.GetArrayFromImage(ct_sitk_img)\n                ct_data.origin = ct_sitk_img.GetOrigin()\n                ct_data.spacing = ct_sitk_img.GetSpacing()\n                \n                # HU conversion (assumes the MHD data is already in HU)\n                # If it is not, uncomment the line below\n                # ct_data.convert_to_hu()\n                \n                return ct_data\n            \n            except Exception as e:\n                logger.error(f\"Failed to load MHD: {e}\")\n                raise\n        \n        def resample_pixels(self, new_spacing=[1.0, 1.0, 1.0]):\n            \"\"\"Resample the voxels to the given spacing, in SimpleITK (x, y, z) order\"\"\"\n            if self.img is None:\n                logger.error(\"No image data to resample\")\n                return\n            \n            logger.info(f\"Resampling voxels, original spacing: {self.spacing}, target spacing: {new_spacing}\")\n
            \n            # Compute the new size. SimpleITK spacing is (x, y, z) while the numpy array is (z, y, x),\n            # so reverse the array shape and work in (x, y, z) order throughout\n            spacing_xyz = np.array(self.spacing, dtype=float)\n            size_xyz = np.array(self.img.shape[::-1], dtype=float)\n            new_size = np.round(size_xyz * spacing_xyz / np.array(new_spacing, dtype=float)).astype(int)\n            \n            # Actual spacing after rounding the size to whole voxels\n            real_new_spacing = spacing_xyz * size_xyz / new_size\n            \n            # Resample with SimpleITK\n            sitk_img = sitk.GetImageFromArray(self.img)\n            sitk_img.SetSpacing(spacing_xyz.tolist())\n            \n            resample = sitk.ResampleImageFilter()\n            resample.SetInterpolator(sitk.sitkLinear)\n            resample.SetOutputSpacing(real_new_spacing.tolist())\n            resample.SetSize([int(s) for s in new_size])\n            resample.SetOutputDirection(sitk_img.GetDirection())\n            resample.SetOutputOrigin(sitk_img.GetOrigin())\n            \n            resampled_img = resample.Execute(sitk_img)\n            self.img = sitk.GetArrayFromImage(resampled_img)\n            self.spacing = tuple(real_new_spacing.tolist())\n            \n            return self\n        \n        def convert_to_hu(self):\n            \"\"\"Convert the image to HU\"\"\"\n            if self.img is None:\n                logger.error(\"No image data to convert\")\n                return\n            \n            if self.is_normalized:\n                logger.info(\"Image is already in HU\")\n                return\n            \n            logger.info(\"Converting image to HU\")\n            \n            # DICOM data usually needs to be converted to HU,\n            # but this implementation assumes the data is already in HU (or similar units);\n            # add a specific conversion here if needed\n            \n            self.is_normalized = True\n            return self\n        \n        def filter_lung_img_mask(self, threshold=-320):\n            \"\"\"Extract the lung region image and mask\"\"\"\n            if self.img is None:\n                logger.error(\"No image data to segment\")\n                return\n            \n            logger.info(\"Segmenting lung region\")\n            
\n            # Make sure the image is in HU\n            if not self.is_normalized:\n                self.convert_to_hu()\n            \n            # Threshold mask: 1 for air-like voxels (below the threshold), 0 otherwise\n            threshold_image = (self.img < threshold).astype(np.uint8)\n            \n            # Find regions connected to the body\n            from scipy import ndimage as ndi\n            \n            # Air outside the body\n            mask = self.fill_body_mask(threshold_image)\n            \n            # XOR with the outside air leaves only the air inside the body\n            lung_mask = np.logical_xor(threshold_image, mask)\n            \n            # Remove small connected regions\n            struct = np.ones((2, 2, 2), dtype=np.bool_)\n            lung_mask = ndi.binary_opening(lung_mask, structure=struct, iterations=2)\n            \n            labeled_mask, num_features = ndi.label(lung_mask)\n            \n            # Keep only the two largest connected regions (the lungs)\n            if num_features > 2:\n                areas = [np.sum(labeled_mask == i) for i in range(1, num_features + 1)]\n                labels = np.argsort(areas)[-2:] + 1\n                \n                # Build the new lung mask\n                lung_mask = np.zeros_like(labeled_mask, dtype=bool)\n                for label in labels:\n                    lung_mask = lung_mask | (labeled_mask == label)\n            \n            # Save the results\n            self.lung_mask = lung_mask\n            self.lung_seg_img = self.img * lung_mask\n            self.has_lung_seg = True\n            \n            return self\n        \n        def fill_body_mask(self, threshold_image):\n            \"\"\"Return the air region outside the body, i.e. the air connected to the volume border\"\"\"\n            # Seed voxels on the six faces of the volume\n            mask = np.zeros_like(threshold_image, dtype=bool)\n            mask[0, :, :] = True\n            mask[-1, :, :] = True\n            mask[:, 0, :] = True\n            mask[:, -1, :] = True\n            mask[:, :, 0] = True\n            mask[:, :, -1] = True\n            \n            # Find all air regions connected to the border (i.e. the air outside the body):\n            # binary_propagation grows the seeds repeatedly, restricted to air voxels\n            from scipy import ndimage as ndi\n            air = threshold_image > 0\n            mask = ndi.binary_propagation(np.logical_and(mask, air), mask=air)\n            mask = ndi.binary_fill_holes(mask)\n            \n            return mask\n\n\n
def extract_lung_from_image(sitk_image):\n    \"\"\"\n    Extract the lung region from a SimpleITK image\n    :param sitk_image: SimpleITK CT image\n    :return: lung segmentation mask (numpy array)\n    \"\"\"\n    # Convert to a numpy array\n    ct_array = sitk.GetArrayFromImage(sitk_image)\n    \n    # Get image information (not used by segment_lung)\n    spacing = sitk_image.GetSpacing()\n    origin = sitk_image.GetOrigin()\n    direction = sitk_image.GetDirection()\n    \n    # Run lung segmentation\n    lung_mask = segment_lung(ct_array)\n    \n    return lung_mask\n\n
def segment_lung(ct_array):\n    \"\"\"\n    Simple lung segmentation algorithm\n    :param ct_array: CT image array [z, y, x]\n    :return: lung mask (uint8)\n    \"\"\"\n    # 1. Thresholding: lung tissue usually lies between air (-1000 HU) and -500 HU\n    binary_image = np.logical_and(ct_array > -1000, ct_array < -500)\n    \n    # 2. Process each slice\n    result_mask = np.zeros_like(binary_image, dtype=bool)\n    \n    for i in range(binary_image.shape[0]):\n        # Current slice\n        slice_img = binary_image[i].copy()\n        \n        # Set the border pixels so that the background is connected\n        slice_img[0,:] = 1\n        slice_img[-1,:] = 1\n        slice_img[:,0] = 1\n        slice_img[:,-1] = 1\n        \n        # Label connected regions in the slice\n        labeled, num_labels = ndimage.label(slice_img)\n        \n        # Sort regions by size (the bounding-box area is used as a size estimate)\n        regions = ndimage.find_objects(labeled)\n        region_sizes = [(i+1, (region[0].stop - region[0].start) * (region[1].stop - region[1].start))\n                        for i, region in enumerate(regions)]\n        region_sizes.sort(key=lambda x: x[1], reverse=True)\n        \n        # The largest region is usually the background or the body; we want the lung regions\n        lung_mask = np.zeros_like(slice_img, dtype=bool)\n        \n        # Take the 5 largest regions (the background is usually the largest; the left and right lungs are likely the 2nd and 3rd)\n        for j in range(min(5, len(region_sizes))):\n            # Skip the largest region (usually the background or the body)\n            if j == 0:\n                continue\n            region_idx = region_sizes[j][0]\n            # Add this region to the lung mask\n            lung_mask[labeled == region_idx] = True\n        \n        # Morphological closing and hole filling to fill holes inside the lungs (e.g. vessels)\n        lung_mask = ndimage.binary_closing(lung_mask, structure=np.ones((5,5)))\n        lung_mask = ndimage.binary_fill_holes(lung_mask)\n        \n        # Store the slice result\n        result_mask[i] = lung_mask\n    \n    # Enforce continuity between slices\n    result_mask = ndimage.binary_closing(result_mask, structure=np.ones((3,3,3)))\n    \n    return result_mask.astype(np.uint8)\n\n
def extract_lung_from_file(file_path):\n    \"\"\"\n    Extract the lung region from a file\n    Supports MHD files, single DICOM files and DICOM series directories\n    :param file_path: path to the file or DICOM directory\n    :return: dict with the lung data\n    \"\"\"\n    # Load the image according to the file type\n    if file_path.lower().endswith('.mhd'):\n        # MHD file\n        ct_image = sitk.ReadImage(file_path)\n    elif file_path.lower().endswith(('.dcm', '.dicom')):\n        # Single DICOM file\n        ct_image = sitk.ReadImage(file_path)\n    elif os.path.isdir(file_path):\n        # DICOM series\n        reader = sitk.ImageSeriesReader()\n        dicom_names = reader.GetGDCMSeriesFileNames(file_path)\n        reader.SetFileNames(dicom_names)\n        ct_image = reader.Execute()\n    else:\n        raise ValueError(f\"Unsupported file type: {file_path}\")\n    \n    # Extract the lungs\n    lung_mask = extract_lung_from_image(ct_image)\n    \n    # Apply the mask to the original CT\n    ct_array = sitk.GetArrayFromImage(ct_image)\n    lung_data = ct_array.copy()\n    lung_data[~lung_mask.astype(bool)] = -2000  # set non-lung voxels to -2000 HU\n    \n    # Build the return value\n    return {\n        'lung_data': lung_data,\n        'lung_mask': lung_mask,\n        'original_ct': ct_array,\n        'spacing': ct_image.GetSpacing(),\n        'origin': ct_image.GetOrigin(),\n        'direction': ct_image.GetDirection()\n    }\n\n
def prepare_data_for_3d_rendering(lung_data):\n    \"\"\"\n    Convert lung data into a format suitable for 3D rendering\n    :param lung_data: lung data dict, as returned by extract_lung_from_file\n    :return: render data dict\n    \"\"\"\n    # Lung mask\n    mask = lung_data['lung_mask']\n    \n    # Downsample by 4 along each axis to reduce the amount of data\n    downsampled_mask = mask[::4, ::4, ::4]\n    \n    # Points to render (simple approach: coordinates of all non-zero voxels, not only the surface)\n    points = np.argwhere(downsampled_mask > 0).tolist()\n    \n    # Bounding box\n    if len(points) > 0:\n        points_array = np.array(points)\n        min_bounds = points_array.min(axis=0).tolist()\n        max_bounds = points_array.max(axis=0).tolist()\n    else:\n        min_bounds = [0, 0, 0]\n        max_bounds = list(downsampled_mask.shape)\n    \n    # Build the render data\n    render_data = {\n        'points': points,\n        'dimensions': downsampled_mask.shape,\n        'min_bounds': min_bounds,\n        'max_bounds': max_bounds\n    }\n    \n    return render_data"
  },
  {
    "path": "deploy/frontend/css/style.css",
    "content": "/* 主要样式表 */\n\n/* 重置和全局样式 */\n* {\n    margin: 0;\n    padding: 0;\n    box-sizing: border-box;\n    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;\n}\n\nbody {\n    background-color: #f5f7fa;\n    color: #333;\n    line-height: 1.6;\n}\n\n.container {\n    display: flex;\n    min-height: 100vh;\n    width: 100%;\n}\n\n/* 侧边栏样式 */\n.sidebar {\n    width: 300px;\n    background-color: #2c3e50;\n    color: #ecf0f1;\n    padding: 20px;\n    overflow-y: auto;\n    box-shadow: 2px 0 5px rgba(0, 0, 0, 0.1);\n    display: flex;\n    flex-direction: column;\n}\n\n.logo {\n    text-align: center;\n    padding: 10px 0 20px;\n    border-bottom: 1px solid #34495e;\n}\n\n.logo h1 {\n    font-size: 1.4rem;\n    color: #3498db;\n}\n\n/* 步骤导航 */\n.steps {\n    margin: 20px 0;\n}\n\n.step {\n    display: flex;\n    padding: 15px 10px;\n    margin-bottom: 10px;\n    border-radius: 5px;\n    background-color: #34495e;\n    opacity: 0.7;\n    transition: all 0.3s ease;\n}\n\n.step.active {\n    background-color: #3498db;\n    opacity: 1;\n    box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);\n}\n\n.step.complete {\n    background-color: #27ae60;\n    opacity: 1;\n}\n\n.step-icon {\n    width: 40px;\n    height: 40px;\n    border-radius: 50%;\n    background-color: rgba(255, 255, 255, 0.2);\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    margin-right: 15px;\n}\n\n.step-icon i {\n    font-size: 1.2rem;\n}\n\n.step-content h3 {\n    font-size: 1rem;\n    margin-bottom: 5px;\n}\n\n.step-content p {\n    font-size: 0.8rem;\n    opacity: 0.8;\n}\n\n/* 上传区域 */\n.upload-area {\n    background-color: #34495e;\n    padding: 15px;\n    border-radius: 5px;\n    margin-bottom: 20px;\n}\n\n.upload-area h3 {\n    margin-bottom: 10px;\n    font-size: 1rem;\n}\n\n.upload-area p {\n    font-size: 0.8rem;\n    opacity: 0.8;\n    margin-bottom: 15px;\n}\n\n.file-input-container {\n    margin-bottom: 15px;\n}\n\ninput[type=\"file\"] {\n    display: 
none;\n}\n\n.select-file-btn {\n    background-color: #3498db;\n    color: white;\n    border: none;\n    padding: 10px 15px;\n    border-radius: 5px;\n    cursor: pointer;\n    width: 100%;\n    margin-bottom: 10px;\n    font-size: 0.9rem;\n    transition: background-color 0.3s;\n}\n\n.select-file-btn:hover {\n    background-color: #2980b9;\n}\n\n#file-info {\n    font-size: 0.8rem;\n    padding: 5px;\n    background-color: rgba(255, 255, 255, 0.1);\n    border-radius: 3px;\n    margin-top: 5px;\n    word-break: break-all;\n}\n\n.patient-info-input {\n    margin-bottom: 15px;\n}\n\n.patient-info-input label {\n    display: block;\n    font-size: 0.8rem;\n    margin-bottom: 5px;\n}\n\n.patient-info-input input {\n    width: 100%;\n    padding: 8px;\n    border-radius: 3px;\n    border: 1px solid #546e7a;\n    background-color: rgba(255, 255, 255, 0.1);\n    color: white;\n}\n\n.upload-btn {\n    background-color: #e74c3c;\n    color: white;\n    border: none;\n    padding: 10px 15px;\n    border-radius: 5px;\n    cursor: pointer;\n    width: 100%;\n    font-size: 0.9rem;\n    transition: background-color 0.3s;\n}\n\n.upload-btn:hover {\n    background-color: #c0392b;\n}\n\n.upload-btn:disabled {\n    background-color: #7f8c8d;\n    cursor: not-allowed;\n}\n\n/* Progress area */\n.progress-area {\n    background-color: #34495e;\n    padding: 15px;\n    border-radius: 5px;\n    margin-bottom: 20px;\n}\n\n.progress-area h3 {\n    margin-bottom: 10px;\n    font-size: 1rem;\n}\n\n.progress-container {\n    height: 20px;\n    background-color: rgba(255, 255, 255, 0.1);\n    border-radius: 10px;\n    overflow: hidden;\n    position: relative;\n    margin-bottom: 10px;\n}\n\n.progress-bar {\n    height: 100%;\n    background-color: #2ecc71;\n    width: 0%;\n    transition: width 0.3s ease;\n    border-radius: 10px;\n}\n\n.progress-text {\n    position: absolute;\n    top: 0;\n    left: 0;\n    right: 0;\n    bottom: 0;\n    display: flex;\n    align-items: center;\n    
justify-content: center;\n    font-size: 0.8rem;\n    color: white;\n}\n\n.progress-message {\n    font-size: 0.8rem;\n    margin: 10px 0;\n    min-height: 40px;\n}\n\n.action-buttons {\n    display: flex;\n    justify-content: center;\n    margin-top: 15px;\n}\n\n.btn {\n    background-color: #3498db;\n    color: white;\n    border: none;\n    padding: 8px 15px;\n    border-radius: 5px;\n    cursor: pointer;\n    font-size: 0.9rem;\n    transition: background-color 0.3s;\n}\n\n.btn:hover {\n    background-color: #2980b9;\n}\n\n.btn:disabled {\n    background-color: #7f8c8d;\n    cursor: not-allowed;\n}\n\n/* Patient info area */\n.patient-info-area {\n    margin-top: auto;\n    padding: 10px;\n    font-size: 0.8rem;\n    background-color: rgba(255, 255, 255, 0.1);\n    border-radius: 5px;\n}\n\n/* Main content area */\n.main-content {\n    flex: 1;\n    display: flex;\n    flex-direction: row;\n    overflow: hidden;\n    position: relative;\n}\n\n/* Viewer container */\n.viewer-container {\n    flex: 0.7; /* Takes 70% of the main area */\n    height: 100vh;\n    padding: 15px;\n    overflow: hidden;\n    position: relative;\n}\n\n/* CT slice viewer */\n#ct-viewer {\n    background-color: #f5f5f5;\n    width: 100%;\n    height: 100%;\n    min-height: 500px; /* Leave enough height to display a slice */\n    position: relative;\n    border-radius: 5px;\n    overflow: auto; /* Scroll when the content overflows */\n    box-shadow: 0 2px 8px rgba(0,0,0,0.15);\n    display: flex; /* Flex layout */\n    flex-direction: column; /* Stack content vertically */\n    justify-content: flex-start; /* Start from the top */\n    align-items: center; /* Center horizontally */\n    padding: 10px; /* Inner padding */\n}\n\n/* Viewer controls */\n.viewer-controls {\n    position: absolute;\n    bottom: 25px;\n    right: 25px;\n    z-index: 10;\n}\n\n/* Results panel */\n.results-panel {\n    flex: 0.25; /* Takes 25% of the main area */\n    height: 100vh;\n    padding: 20px;\n    background-color: white;\n    box-shadow: -2px 0 5px rgba(0, 0, 0, 0.1);\n    display: flex;\n    flex-direction: column;\n    overflow-y: auto;\n}\n\n.results-header {\n    display: flex;\n    
justify-content: space-between;\n    align-items: center;\n    margin-bottom: 20px;\n    padding-bottom: 10px;\n    border-bottom: 1px solid #e0e0e0;\n}\n\n.results-header h2 {\n    font-size: 1.2rem;\n    color: #2c3e50;\n}\n\n.summary {\n    font-size: 0.9rem;\n    color: #7f8c8d;\n}\n\n/* Results layout */\n.results-layout {\n    display: flex;\n    flex-direction: column;\n    gap: 20px;\n    height: 100%;\n    min-height: 200px;\n}\n\n/* Nodule list section */\n.nodule-list-section {\n    flex: 1;\n    max-height: 300px;\n}\n\n.nodule-list-section h3 {\n    margin-bottom: 10px;\n    color: #2c3e50;\n    font-size: 1rem;\n}\n\n#nodule-list {\n    margin-top: 10px;\n    overflow-y: auto;\n    max-height: 250px;\n    border: 1px solid #eee;\n    border-radius: 5px;\n}\n\n.nodule-item {\n    background-color: #f5f5f5;\n    border-radius: 5px;\n    padding: 10px;\n    margin-bottom: 10px;\n    cursor: pointer;\n    border-left: 4px solid #ccc;\n    transition: all 0.2s ease;\n}\n\n.nodule-item:hover {\n    background-color: #e9e9e9;\n    transform: translateX(2px);\n}\n\n.nodule-item.selected {\n    background-color: #e1f5fe;\n    border-left-color: #2196F3;\n}\n\n/* Nodule detail section */\n.nodule-detail-section {\n    flex: 2;\n}\n\n#nodule-detail-area {\n    background-color: #f5f5f5;\n    border-radius: 5px;\n    padding: 15px;\n    height: 100%;\n    box-shadow: 0 1px 3px rgba(0,0,0,0.1);\n}\n\n.no-selection {\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    height: 100%;\n    min-height: 150px;\n    color: #999;\n    font-style: italic;\n}\n\n/* Placeholder shown when no nodules were detected */\n.no-nodules {\n    text-align: center;\n    padding: 20px;\n    color: #666;\n    font-style: italic;\n    background-color: #f9f9f9;\n    border-radius: 5px;\n}\n\n/* Preview area */\n.nodule-preview {\n    margin-top: 20px;\n}\n\n.nodule-preview h4 {\n    font-size: 14px;\n    margin-bottom: 10px;\n}\n\n/* Preview views */\n.preview-views {\n    display: flex;\n    flex-wrap: wrap;\n    justify-content: space-between;\n    
margin-top: 15px;\n}\n\n.preview-view {\n    width: 31%;\n    margin-bottom: 15px;\n    background-color: #f9f9f9;\n    border-radius: 5px;\n    overflow: hidden;\n    box-shadow: 0 1px 3px rgba(0,0,0,0.1);\n}\n\n.preview-view h5 {\n    margin: 0;\n    padding: 8px;\n    font-size: 14px;\n    background-color: #eaeaea;\n    text-align: center;\n}\n\n.preview-img {\n    height: 150px;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    position: relative;\n    overflow: hidden;\n}\n\n.preview-img img {\n    max-width: 100%;\n    max-height: 100%;\n    display: block;\n    border: none;\n}\n\n/* Small loading indicator */\n.loading-spinner.small {\n    padding: 10px;\n}\n\n.loading-spinner.small .spinner {\n    width: 25px;\n    height: 25px;\n    border-width: 3px;\n}\n\n.loading-spinner.small p {\n    font-size: 12px;\n    margin-top: 5px;\n}\n\n/* Small error message */\n.error-message.small {\n    font-size: 12px;\n    padding: 10px;\n    margin: 5px;\n}\n\n@media screen and (max-width: 768px) {\n    .preview-view {\n        width: 100%;\n        margin-bottom: 10px;\n    }\n}\n\n/* Toast messages */\n.message-container {\n    position: fixed;\n    top: 20px;\n    right: 20px;\n    width: 300px;\n    z-index: 1000;\n}\n\n.message {\n    padding: 15px;\n    margin-bottom: 10px;\n    border-radius: 5px;\n    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);\n    animation: slideIn 0.3s ease;\n}\n\n.message.info {\n    background-color: #3498db;\n    color: white;\n}\n\n.message.success {\n    background-color: #2ecc71;\n    color: white;\n}\n\n.message.warning {\n    background-color: #f39c12;\n    color: white;\n}\n\n.message.error {\n    background-color: #e74c3c;\n    color: white;\n}\n\n.message.fade-out {\n    animation: fadeOut 0.5s ease forwards;\n}\n\n@keyframes slideIn {\n    from {\n        transform: translateX(100%);\n        opacity: 0;\n    }\n    to {\n        transform: translateX(0);\n        opacity: 1;\n    }\n}\n\n@keyframes fadeOut {\n    from {\n        
opacity: 1;\n    }\n    to {\n        opacity: 0;\n    }\n}\n\n/* Responsive adjustments */\n@media (max-width: 768px) {\n    .container {\n        flex-direction: column;\n    }\n    \n    .sidebar {\n        width: 100%;\n        max-height: 300px;\n    }\n    \n    .main-content {\n        height: calc(100vh - 300px);\n    }\n    \n    .ct-viewer-container {\n        height: 40%;\n    }\n}\n\n/* Nodule list items */\n.nodule-header {\n    display: flex;\n    justify-content: space-between;\n    align-items: center;\n    margin-bottom: 8px;\n}\n\n.nodule-name {\n    font-weight: 600;\n    font-size: 14px;\n}\n\n.nodule-probability {\n    padding: 3px 6px;\n    border-radius: 10px;\n    font-size: 12px;\n    font-weight: bold;\n}\n\n.nodule-probability.high {\n    background-color: #ffebee;\n    color: #d32f2f;\n}\n\n.nodule-probability.medium {\n    background-color: #fff8e1;\n    color: #ff8f00;\n}\n\n.nodule-probability.low {\n    background-color: #e8f5e9;\n    color: #388e3c;\n}\n\n.nodule-details {\n    font-size: 12px;\n    color: #666;\n}\n\n.nodule-detail {\n    margin-bottom: 4px;\n}\n\n.detail-label {\n    color: #777;\n    margin-right: 5px;\n}\n\n/* Detail rows */\n.detail-item {\n    margin-bottom: 8px;\n    display: flex;\n}\n\n.detail-label {\n    width: 90px;\n    color: #666;\n}\n\n.detail-value {\n    flex: 1;\n    font-weight: 500;\n}\n\n.detail-value.high {\n    color: #d32f2f;\n}\n\n.detail-value.medium {\n    color: #ff8f00;\n}\n\n.detail-value.low {\n    color: #388e3c;\n}\n\n/* Loading state */\n.loading-spinner {\n    display: flex;\n    flex-direction: column;\n    align-items: center;\n    justify-content: center;\n    padding: 30px;\n}\n\n.spinner {\n    width: 30px;\n    height: 30px;\n    border: 3px solid #f3f3f3;\n    border-top: 3px solid #2196F3;\n    border-radius: 50%;\n    animation: spin 1s linear infinite;\n    margin-bottom: 10px;\n}\n\n@keyframes spin {\n    0% { transform: rotate(0deg); }\n    100% { transform: rotate(360deg); }\n}\n\n.error-message {\n    color: 
#d32f2f;\n    padding: 15px;\n    text-align: center;\n    background-color: #ffebee;\n    border-radius: 4px;\n}\n\n/* CT slice views */\n.slice-views-container {\n    display: flex;\n    flex-wrap: wrap;\n    justify-content: space-between;\n    width: 100%;\n    height: 100%;\n    min-height: 450px; /* Leave enough height to display a slice */\n    padding: 10px;\n    box-sizing: border-box;\n}\n\n.slice-view-container {\n    width: calc(33.33% - 10px);\n    height: 450px; /* Fixed height */\n    min-height: 300px;\n    display: flex;\n    flex-direction: column;\n    background-color: #222;\n    border-radius: 5px;\n    overflow: hidden;\n    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.2);\n    margin-bottom: 15px;\n}\n\n.slice-view-header {\n    background-color: #333;\n    color: white;\n    padding: 8px 12px;\n    font-weight: bold;\n    font-size: 14px;\n    text-align: center;\n}\n\n.slice-view {\n    flex: 1;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    overflow: hidden;\n    background-color: #000;\n    position: relative;\n    min-height: 350px; /* Keep the slice view area tall enough */\n}\n\n.slice-view img {\n    max-width: 100%;\n    max-height: 100%;\n    object-fit: contain;\n    display: block;\n    border: 1px solid #444; /* Border makes the image edges easier to see */\n}\n\n.slice-view-controls {\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    padding: 8px;\n    background-color: #333;\n}\n\n.slice-btn {\n    background-color: #444;\n    color: white;\n    border: none;\n    border-radius: 3px;\n    width: 30px;\n    height: 30px;\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    cursor: pointer;\n    transition: background-color 0.2s;\n}\n\n.slice-btn:hover {\n    background-color: #666;\n}\n\n.slice-btn:active {\n    background-color: #555;\n}\n\n.slice-index {\n    color: white;\n    margin: 0 15px;\n    font-size: 14px;\n    min-width: 60px;\n    text-align: center;\n}\n\n/* Responsive adjustments */\n@media (max-width: 1200px) {\n    
.slice-view-container {\n        width: calc(50% - 10px);\n    }\n}\n\n@media (max-width: 768px) {\n    .slice-view-container {\n        width: 100%;\n    }\n} "
  },
  {
    "path": "deploy/frontend/index.html",
    "content": "<!DOCTYPE html>\n<html lang=\"zh-CN\">\n<head>\n    <meta charset=\"UTF-8\">\n    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n    <title>CT 肺结节检测系统</title>\n    <link rel=\"stylesheet\" href=\"css/style.css\">\n    <link rel=\"stylesheet\" href=\"https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/css/all.min.css\">\n</head>\n<body>\n    <div class=\"container\">\n        <!-- Sidebar: processing steps and progress -->\n        <aside class=\"sidebar\">\n            <div class=\"logo\">\n                <h1>CT 肺结节检测</h1>\n            </div>\n            \n            <!-- Step navigation -->\n            <div class=\"steps\">\n                <div class=\"step active\" id=\"step-upload\">\n                    <div class=\"step-icon\">\n                        <i class=\"fas fa-upload\"></i>\n                    </div>\n                    <div class=\"step-content\">\n                        <h3>1. 上传数据</h3>\n                        <p>支持MHD、DICOM、NIfTI格式</p>\n                    </div>\n                </div>\n                \n                <div class=\"step\" id=\"step-detect\">\n                    <div class=\"step-icon\">\n                        <i class=\"fas fa-search\"></i>\n                    </div>\n                    <div class=\"step-content\">\n                        <h3>2. 检测分析</h3>\n                        <p>AI模型分析结节</p>\n                    </div>\n                </div>\n                \n                <div class=\"step\" id=\"step-visualize\">\n                    <div class=\"step-icon\">\n                        <i class=\"fas fa-eye\"></i>\n                    </div>\n                    <div class=\"step-content\">\n                        <h3>3. 
可视化结果</h3>\n                        <p>查看CT切片与结节</p>\n                    </div>\n                </div>\n            </div>\n\n            <!-- Upload area -->\n            <div class=\"upload-area\" id=\"upload-area\">\n                <h3>上传CT文件</h3>\n                <p>支持的格式: MHD, DICOM, NIfTI, 压缩包</p>\n\n                <div class=\"file-input-container\">\n                    <input type=\"file\" id=\"ct-file\" accept=\".mhd,.raw,.dcm,.dicom,.nii,.nii.gz,.zip\">\n                    <button class=\"select-file-btn\" id=\"select-file-btn\">选择CT文件</button>\n                    <div id=\"file-info\">请选择CT文件</div>\n                </div>\n\n                <div class=\"patient-info-input\">\n                    <label for=\"patient-id\">患者ID (可选):</label>\n                    <input type=\"text\" id=\"patient-id\" placeholder=\"输入患者ID\">\n                </div>\n\n                <button class=\"upload-btn\" id=\"upload-btn\" disabled>上传数据</button>\n            </div>\n\n            <!-- Progress area -->\n            <div class=\"progress-area\" id=\"progress-area\" style=\"display: none;\">\n                <h3>处理进度</h3>\n                <div class=\"progress-container\">\n                    <div class=\"progress-bar\" id=\"progress-bar\"></div>\n                    <div class=\"progress-text\" id=\"progress-text\">0%</div>\n                </div>\n                <div class=\"progress-message\" id=\"progress-message\">正在准备处理...</div>\n                \n                <div class=\"action-buttons\">\n                    <button class=\"btn\" id=\"start-detection-btn\">开始检测</button>\n                </div>\n            </div>\n            \n            <!-- Patient info -->\n            <div class=\"patient-info-area\">\n                <div id=\"patient-info\">患者ID: 未指定</div>\n            </div>\n        </aside>\n        \n        <!-- Main content -->\n        <main class=\"main-content\">\n            <!-- CT slice viewer -->\n            <div class=\"viewer-container\">\n                <div 
id=\"ct-viewer\"></div>\n                <div class=\"viewer-controls\">\n                    <button class=\"btn\" id=\"reset-view-btn\">\n                        <i class=\"fas fa-sync-alt\"></i> 重置视图\n                    </button>\n                </div>\n            </div>\n            \n            <!-- Results panel -->\n            <div class=\"results-panel\">\n                <div class=\"results-header\">\n                    <h2>检测结果</h2>\n                    <div class=\"summary\">共 <span id=\"nodules-count\">0</span> 个结节</div>\n                </div>\n                \n                <div class=\"results-layout\">\n                    <!-- Nodule list -->\n                    <div class=\"nodule-list-section\">\n                        <h3>结节列表</h3>\n                        <div id=\"nodule-list\">\n                            <!-- Populated dynamically by js/main.js -->\n                            <div class=\"no-nodules\">未检测到结节</div>\n                        </div>\n                    </div>\n                    \n                    <!-- Nodule details -->\n                    <div class=\"nodule-detail-section\">\n                        <div id=\"nodule-detail-area\">\n                            <!-- Populated dynamically by js/main.js -->\n                            <div class=\"no-selection\">请从上方选择结节</div>\n                        </div>\n                    </div>\n                </div>\n            </div>\n        </main>\n    </div>\n    \n    <!-- Toast message container -->\n    <div class=\"message-container\" id=\"message-container\"></div>\n    \n    <!-- Scripts -->\n    <script src=\"js/main.js\"></script>\n</body>\n</html> "
  },
  {
    "path": "deploy/frontend/js/main.js",
    "content": "/**\n * CT lung nodule detection system: main JS file\n */\n\n// Global state\nlet currentSessionId = null;\nlet currentSliceIndex = 0;\nlet maxSliceIndex = 0;\nlet lungSegmentationLoaded = false;\nlet progressInterval = null;\n\n// Initialize once the DOM is ready\ndocument.addEventListener('DOMContentLoaded', function() {\n    console.log('CT肺结节检测系统初始化');\n    \n    // Register event listeners\n    initializeEventListeners();\n    // Initialize the step UI\n    initializeSteps();\n    // Set up file upload\n    setupFileUpload();\n    // Check whether a session ID was saved in sessionStorage (the value is not used yet)\n    const savedSessionId = sessionStorage.getItem('currentSessionId');\n});\n\n// Initialize the step display\nfunction initializeSteps() {\n    // Start with the upload step active\n    updateUIState('initial');\n}\n\n// Set up file selection and upload\nfunction setupFileUpload() {\n    const fileInput = document.getElementById('ct-file');\n    const selectFileBtn = document.getElementById('select-file-btn');\n    const uploadBtn = document.getElementById('upload-btn');\n    const fileInfo = document.getElementById('file-info');\n    \n    // The visible button opens the hidden file input\n    selectFileBtn.addEventListener('click', function() {\n        fileInput.click();\n    });\n    \n    // Validate the file whenever the selection changes\n    fileInput.addEventListener('change', function() {\n        if (this.files.length > 0) {\n            const file = this.files[0];\n            \n            // Check the file type\n            const fileExt = getFileExtension(file.name).toLowerCase();\n            if (!['zip', 'mhd', 'dcm', 'dicom', 'nii', 'nii.gz'].includes(fileExt)) {\n                showMessage('请上传支持的格式: MHD, DICOM, NIfTI 或 ZIP文件', 'error');\n                fileInfo.textContent = '请选择支持的CT文件格式';\n                uploadBtn.disabled = true;\n                return;\n            }\n            \n            fileInfo.textContent = `已选择: ${file.name} (${formatFileSize(file.size)})`;\n            uploadBtn.disabled = false;\n        } else {\n            fileInfo.textContent = '请选择CT文件';\n            uploadBtn.disabled = true;\n        }\n    });\n    \n    // Upload button click\n    uploadBtn.addEventListener('click', function() {\n     
   if (fileInput.files.length === 0) {\n            showMessage('请先选择文件', 'error');\n            return;\n        }\n        \n        // Upload the selected file\n        uploadFile(fileInput.files[0]);\n    });\n}\n\n// Register event listeners\nfunction initializeEventListeners() {\n    // Reset-view button\n    document.getElementById('reset-view-btn').addEventListener('click', function() {\n        if (maxSliceIndex > 0) {\n            // Jump back to the middle slice\n            currentSliceIndex = Math.floor(maxSliceIndex / 2);\n            loadSlice(currentSliceIndex);\n            showMessage('已重置视图至中间切片', 'info');\n        } else {\n            showMessage('无法重置视图，未加载数据', 'warning');\n        }\n    });\n\n    // Start-detection button\n    document.getElementById('start-detection-btn').addEventListener('click', startDetection);\n    // Delegated click handler for the nodule list\n    document.getElementById('nodule-list').addEventListener('click', function(e) {\n        if (e.target.closest('.nodule-item')) {\n            const noduleItem = e.target.closest('.nodule-item');\n            const noduleId = noduleItem.dataset.id;\n            // Clear the selection on all items\n            document.querySelectorAll('.nodule-item').forEach(item => {\n                item.classList.remove('selected');\n            });\n            \n            // Mark the clicked item as selected\n            noduleItem.classList.add('selected');\n            \n            // Load the nodule details\n            console.log(`选中结节 ${noduleId}`);\n            loadNoduleDetails(noduleId);\n        }\n    });\n}\n\n// Get the file extension, treating '.nii.gz' as a single extension\nfunction getFileExtension(filename) {\n    const lower = filename.toLowerCase();\n    // A plain split('.') would return only 'gz' for .nii.gz files\n    if (lower.endsWith('.nii.gz')) {\n        return 'nii.gz';\n    }\n    return lower.split('.').pop();\n}\n\n// Format a byte count for display\nfunction formatFileSize(bytes) {\n    if (bytes === 0) return '0 Bytes';\n    const k = 1024;\n    const sizes = ['Bytes', 'KB', 'MB', 'GB'];\n    const i = Math.floor(Math.log(bytes) / Math.log(k));\n    return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];\n}\n\n// Upload a file to the backend\nfunction uploadFile(file) {\n    console.log('开始上传文件:', file.name);\n    \n    // Swap the upload area for the progress area\n    document.getElementById('upload-area').style.display = 
'none';\n    document.getElementById('progress-area').style.display = 'block';\n    \n    // Initial progress state\n    updateProgress(10, '正在上传文件...');\n    \n    // Build the multipart form body\n    const formData = new FormData();\n    formData.append('file', file);\n    \n    // Attach the patient ID, if one was entered\n    const patientId = document.getElementById('patient-id').value;\n    if (patientId) {\n        formData.append('patient_id', patientId);\n        document.getElementById('patient-info').textContent = `患者ID: ${patientId}`;\n    }\n    \n    // Send the upload request\n    fetch('/api/upload', {\n        method: 'POST',\n        body: formData\n    })\n    .then(response => {\n        console.log('服务器响应状态:', response.status);\n        if (!response.ok) {\n            throw new Error(`网络响应错误: ${response.status}`);\n        }\n        return response.json();\n    })\n    .then(data => {\n        console.log('上传响应:', data);\n        if (data.success) {\n            // Remember the session ID\n            currentSessionId = data.session_id;\n            console.log('保存会话ID:', currentSessionId);\n            \n            // Also persist it to sessionStorage\n            sessionStorage.setItem('currentSessionId', currentSessionId);\n            \n            // Mark the upload step as complete\n            updateUIState('uploaded');\n            \n            // Update progress\n            updateProgress(100, '文件上传成功');\n            \n            // Describe the detected file type\n            const typeInfo = data.file_type ? 
`检测到${data.file_type}格式CT数据` : '文件已上传';\n            \n            // The message depends on whether the backend starts detection automatically\n            if (data.auto_detect) {\n                document.getElementById('progress-message').textContent = `${typeInfo}，正在自动开始检测...`;\n                document.getElementById('start-detection-btn').disabled = true;\n                \n                // Start polling detection progress\n                startProgressPolling();\n            } else {\n                document.getElementById('progress-message').textContent = `${typeInfo}，可以开始检测`;\n                document.getElementById('start-detection-btn').disabled = false;\n            }\n            \n            // Show a success message\n            showMessage('文件上传成功', 'success');\n        } else {\n            throw new Error(data.error || '上传失败');\n        }\n    })\n    .catch(error => {\n        console.error('文件上传错误:', error);\n        updateUIState('initial');\n        document.getElementById('progress-message').textContent = `上传失败: ${error.message}`;\n        showMessage(`上传失败: ${error.message}`, 'error');\n        \n        // Show the upload area again after 3 seconds\n        setTimeout(() => {\n            document.getElementById('progress-area').style.display = 'none';\n            document.getElementById('upload-area').style.display = 'block';\n        }, 3000);\n    });\n}\n\n// Start detection\nasync function startDetection() {\n    if (!currentSessionId) {\n        showMessage('请先上传CT文件', 'error');\n        return;\n    }\n    \n    try {\n        // Update UI state\n        updateUIState('detecting');\n        // Clear results from any previous run\n        resetResults();\n        // Notify the user\n        showMessage('正在开始检测...', 'info');\n        // Send the detection request\n        const response = await fetch('/api/detect', {\n            method: 'POST',\n            headers: {\n                'Content-Type': 'application/json'\n            },\n            body: JSON.stringify({\n                session_id: currentSessionId\n            })\n        });\n        \n        const data = await response.json();\n        
if (!data.success) {\n            throw new Error(data.error || '启动检测失败');\n        }\n        \n        // Show a success message\n        showMessage('检测已启动', 'success');\n        \n        // Start polling detection progress\n        startProgressPolling();\n        \n    } catch (error) {\n        console.error('启动检测时出错:', error);\n        showMessage(`启动检测失败: ${error.message}`, 'error');\n        updateUIState('uploaded');\n    }\n}\n\n// Poll the backend for detection progress\nfunction startProgressPolling() {\n    // Stop any previous polling loop\n    if (progressInterval) {\n        clearInterval(progressInterval);\n    }\n    \n    // Reset the progress bar\n    updateProgress(0, '正在初始化检测...');\n    \n    // Show the progress area\n    document.getElementById('progress-area').style.display = 'block';\n    \n    // Reset the lung-slice loaded flag\n    lungSegmentationLoaded = false;\n    \n    // Poll once per second\n    progressInterval = setInterval(async function() {\n        try {\n            if (!currentSessionId) {\n                clearInterval(progressInterval);\n                progressInterval = null;\n                return;\n            }\n            \n            const response = await fetch(`/api/progress/${currentSessionId}`);\n            const data = await response.json();\n            \n            if (!data.success) {\n                throw new Error(data.error || '获取进度失败');\n            }\n            \n            // Update the progress bar\n            updateProgress(data.progress, data.message);\n            \n            // From 40% onward, try to load the lung slice data\n            if (data.progress >= 40 && !lungSegmentationLoaded) {\n                loadLungSegmentation();\n            }\n            \n            // Stop polling once the job has completed or failed\n            if (data.status === 'completed') {\n                clearInterval(progressInterval);\n                progressInterval = null;\n                \n                // Fetch the detection results\n                fetchResults();\n            } else if (data.status === 'error') {\n                clearInterval(progressInterval);\n                progressInterval = null;\n                showMessage(`检测失败: ${data.error || '未知错误'}`, 'error');\n       
         updateUIState('uploaded');\n            }\n        } catch (error) {\n            console.error('获取进度时出错:', error);\n            showMessage(`获取进度失败: ${error.message}`, 'error');\n        }\n    }, 1000); // poll every second\n}\n\n// Update the progress bar and its message\nfunction updateProgress(progress, message) {\n    const progressBar = document.getElementById('progress-bar');\n    const progressText = document.getElementById('progress-text');\n    const progressMessage = document.getElementById('progress-message');\n    \n    if (progressBar && progressText) {\n        progressBar.style.width = `${progress}%`;\n        progressText.textContent = `${progress}%`;\n    }\n    \n    if (progressMessage && message) {\n        progressMessage.textContent = message;\n    }\n}\n\n// Load the lung slice metadata and initialize the viewer\nfunction loadLungSegmentation() {\n    console.log('加载肺部切片数据...');\n    \n    // A session ID is required\n    if (!currentSessionId) {\n        console.error('无法加载肺部切片: 没有会话ID');\n        return;\n    }\n    \n    // Show a loading message\n    document.getElementById('progress-message').textContent = '加载肺部切片数据...';\n    \n    // Fetch the slice metadata\n    fetch(`/api/lung_slices_info/${currentSessionId}`)\n        .then(response => {\n            console.log('lung_slices_info 响应状态:', response.status);\n            if (!response.ok) {\n                throw new Error(`服务器返回错误状态: ${response.status}`);\n            }\n            return response.json();\n        })\n        .then(data => {\n            if (!data.success) {\n                throw new Error(data.error || '获取切片信息失败');\n            }\n            \n            console.log('获取到切片信息:', data.slices_info);\n            \n            // Update the global slice range\n            maxSliceIndex = data.slices_info.z_slices - 1;\n            // Start at the middle slice\n            currentSliceIndex = Math.floor(maxSliceIndex / 2);\n            \n            // Build the viewer DOM\n            initCTViewer();\n            \n            // Load the initial slice\n            loadSlice(currentSliceIndex);\n            \n            // Record that loading finished\n            
document.getElementById('progress-message').textContent = '肺部切片数据加载完成';\n            lungSegmentationLoaded = true;\n        })\n        .catch(error => {\n            console.error('获取肺部切片信息失败:', error);\n            document.getElementById('progress-message').textContent = `加载失败: ${error.message}`;\n            showMessage(`肺部切片数据加载失败: ${error.message}`, 'error');\n        });\n}\n\n// Build the CT viewer DOM\nfunction initCTViewer() {\n    console.log('初始化CT查看器...');\n    \n    // Locate the viewer container\n    const viewerContainer = document.getElementById('ct-viewer');\n    if (!viewerContainer) {\n        console.error('找不到CT查看器容器元素');\n        return;\n    }\n    \n    // Clear any previous content\n    viewerContainer.innerHTML = '';\n    \n    // Create the slice view; only axial (Z-axis) slices are shown\n    const sliceView = document.createElement('div');\n    sliceView.className = 'slice-view';\n    sliceView.id = 'main-slice-view';\n    sliceView.innerHTML = `\n        <div class=\"slice-view-header\">横断面 (Z轴切片 - XY平面)</div>\n        <div class=\"slice-view-content\">\n            <div class=\"slice-image-container\">\n                <img id=\"slice-img\" src=\"\" alt=\"CT切片\">\n            </div>\n        </div>\n        <div class=\"slice-view-controls\">\n            <button class=\"slice-btn prev-slice\" id=\"prev-btn\"><i class=\"fas fa-chevron-left\"></i></button>\n            <span class=\"slice-index\" id=\"slice-index\">0 / 0</span>\n            <button class=\"slice-btn next-slice\" id=\"next-btn\"><i class=\"fas fa-chevron-right\"></i></button>\n        </div>\n    `;\n    \n    viewerContainer.appendChild(sliceView);\n    \n    // Mouse wheel steps through slices; addEventListener discards a duplicate registration of the same handler, so repeated calls do not stack listeners\n    viewerContainer.addEventListener('wheel', handleSliceScroll);\n    \n    // Previous / next buttons\n    document.getElementById('prev-btn').addEventListener('click', () => changeSlice(-1));\n    document.getElementById('next-btn').addEventListener('click', () => changeSlice(1));\n}\n\n// Mouse wheel handler\nfunction handleSliceScroll(event) {\n    event.preventDefault();\n    \n    // Scroll direction: -1, 0 or 1\n    const delta = 
Math.sign(event.deltaY);\n    \n    // Step to the neighbouring slice\n    changeSlice(delta);\n}\n\n// Move by delta slices, clamped to the valid range\nfunction changeSlice(delta) {\n    // Compute the new slice index\n    let newIndex = currentSliceIndex + delta;\n    \n    // Clamp to [0, maxSliceIndex]\n    newIndex = Math.max(0, Math.min(maxSliceIndex, newIndex));\n    \n    // Nothing to do if the index did not change\n    if (newIndex === currentSliceIndex) {\n        return;\n    }\n    \n    // Update the current index\n    currentSliceIndex = newIndex;\n    \n    // Load the new slice\n    loadSlice(currentSliceIndex);\n}\n\n// Load one axial slice image\nfunction loadSlice(sliceIndex) {\n    console.log(`加载Z轴切片，索引: ${sliceIndex}`);\n    \n    // Locate the image element\n    const imgElement = document.getElementById('slice-img');\n    if (!imgElement) {\n        console.error('找不到切片图像元素');\n        return;\n    }\n    \n    // Update the slice counter\n    const indexElement = document.getElementById('slice-index');\n    if (indexElement) {\n        indexElement.textContent = `${sliceIndex} / ${maxSliceIndex}`;\n    }\n    \n    // Build the API URL\n    const url = `/api/lung_slice/${currentSessionId}/z/${sliceIndex}`;\n    // Append a timestamp to bypass the browser cache\n    const finalUrl = `${url}?t=${Date.now()}`;\n    // Load handler\n    imgElement.onload = function() {\n        console.log('切片图像加载成功');\n        // Draw nodule markers only after the image has loaded: renderNoduleMarkers returns early while sliceImg.complete is still false\n        if (window.renderNoduleMarkers) {\n            window.renderNoduleMarkers(sliceIndex);\n        }\n    };\n    \n    // Error handler\n    imgElement.onerror = function() {\n        console.error('切片图像加载失败');\n        // Detach the handler first: assigning an empty src can fire another error event\n        imgElement.onerror = null;\n        imgElement.src = ''; // clear the broken image\n        showMessage('切片图像加载失败', 'error');\n    };\n    \n    // Start loading the image\n    imgElement.src = finalUrl;\n}\n\n// Fetch the detection results\nfunction fetchResults() {\n    console.log('获取检测结果...');\n    \n    // A session ID is required\n    if (!currentSessionId) {\n        console.error('无法获取结果: 没有会话ID');\n        return;\n    }\n    \n    // Update the status message\n    document.getElementById('progress-message').textContent = '加载检测结果...';\n    \n    // Load the lung slices if that has not happened yet\n    if (!lungSegmentationLoaded) {\n        loadLungSegmentation();\n    }\n    \n    // Request the results\n    
fetch(`/api/results/${currentSessionId}`)\n        .then(response => {\n            if (!response.ok) {\n                throw new Error(`服务器返回错误状态: ${response.status}`);\n            }\n            return response.json();\n        })\n        .then(data => {\n            console.log('检测结果:', data);\n            \n            // Update the UI with the results\n            if (data.nodules && data.nodules.length > 0) {\n                updateUIState('detected');\n                document.getElementById('progress-message').textContent = `检测到 ${data.nodules.length} 个结节`;\n                \n                // Keep the nodules globally for the marker renderer\n                window.nodules = data.nodules;\n                \n                // Refresh the nodule list\n                updateNoduleList(data.nodules);\n                \n                // Set up the nodule markers\n                createNoduleMarkers();\n            } else {\n                updateUIState('detected');\n                document.getElementById('progress-message').textContent = '未检测到结节';\n                updateNoduleList([]);\n            }\n        })\n        .catch(error => {\n            console.error('获取结果出错:', error);\n            document.getElementById('progress-message').textContent = `获取结果失败: ${error.message}`;\n            showMessage(`获取结果出错: ${error.message}`, 'error');\n        });\n}\n\n// Rebuild the nodule list\nfunction updateNoduleList(nodules) {\n    console.log('更新结节列表，数量:', nodules ? nodules.length : 0);\n    \n    const noduleListContainer = document.getElementById('nodule-list');\n    if (!noduleListContainer) {\n        console.error('找不到结节列表容器');\n        return;\n    }\n    \n    // Clear the current list\n    noduleListContainer.innerHTML = '';\n    \n    // Update the nodule count\n    const noduleCountElement = document.getElementById('nodules-count');\n    if (noduleCountElement) {\n        noduleCountElement.textContent = nodules ? 
nodules.length : '0';\n    }\n    \n    // 如果没有结节\n    if (!nodules || nodules.length === 0) {\n        noduleListContainer.innerHTML = '<div class=\"no-nodules\">未检测到结节</div>';\n        return;\n    }\n    \n    // 添加结节到列表\n    nodules.forEach((nodule, index) => {\n        try {\n            const noduleItem = document.createElement('div');\n            noduleItem.className = 'nodule-item';\n            noduleItem.dataset.id = nodule.id || index;\n            \n            // 提取结节信息\n            const diameter = (nodule.diameter_mm || 10).toFixed(1);\n            const probability = ((nodule.probability || 0.5) * 100).toFixed(0);\n            \n            // 检查坐标数据\n            let coordsStr = '[未知]';\n            if (nodule.voxel_coords && Array.isArray(nodule.voxel_coords)) {\n                coordsStr = `[${nodule.voxel_coords.map(v => Math.round(v)).join(', ')}]`;\n            }\n            \n            // 结节名称和详细信息\n            noduleItem.innerHTML = `\n                <div class=\"nodule-header\">\n                    <span class=\"nodule-name\">结节 #${index + 1}</span>\n                    <span class=\"nodule-probability ${probability > 70 ? 'high' : (probability > 30 ? 
'medium' : 'low')}\">\n                        ${probability}%\n                    </span>\n                </div>\n                <div class=\"nodule-details\">\n                    <div class=\"nodule-detail\">\n                        <span class=\"detail-label\">直径:</span>\n                        <span class=\"detail-value\">${diameter} mm</span>\n                    </div>\n                    <div class=\"nodule-detail\">\n                        <span class=\"detail-label\">位置:</span>\n                        <span class=\"detail-value\">${coordsStr}</span>\n                    </div>\n                </div>\n            `;\n            \n            // 添加到列表\n            noduleListContainer.appendChild(noduleItem);\n        } catch (error) {\n            console.error(`处理结节 #${index+1} 时出错:`, error);\n        }\n    });\n    \n    // 默认选中第一个结节\n    if (nodules.length > 0) {\n        const firstNoduleItem = noduleListContainer.querySelector('.nodule-item');\n        if (firstNoduleItem) {\n            firstNoduleItem.classList.add('selected');\n            const noduleId = firstNoduleItem.dataset.id;\n            loadNoduleDetails(noduleId);\n        }\n    }\n}\n\n// 创建结节标记\nfunction createNoduleMarkers() {\n    if (!window.nodules || !window.nodules.length) {\n        console.log('没有结节数据可用于创建标记');\n        return;\n    }\n    \n    console.log(`为 ${window.nodules.length} 个结节创建标记`);\n    \n    // 定义全局渲染函数\n    window.renderNoduleMarkers = function(currentZIndex) {\n        const imageContainer = document.querySelector('.slice-image-container');\n        const sliceImg = document.getElementById('slice-img');\n        \n        if (!imageContainer || !sliceImg || !sliceImg.complete) {\n            console.log('图像容器或图像未准备好，无法渲染结节标记');\n            return;\n        }\n        \n        // 清除所有现有标记\n        const existingMarkers = document.querySelectorAll('.nodule-marker');\n        existingMarkers.forEach(marker => 
marker.remove());\n        \n        // 获取图像尺寸\n        const imgWidth = sliceImg.width;\n        const imgHeight = sliceImg.height;\n        \n        if (imgWidth <= 0 || imgHeight <= 0 || sliceImg.naturalWidth <= 0 || sliceImg.naturalHeight <= 0) {\n            console.log('图像尺寸无效，无法渲染标记');\n            return;\n        }\n        \n        // 为当前切片中的每个结节创建标记\n        window.nodules.forEach((nodule, index) => {\n            try {\n                // 检查结节是否有有效坐标\n                if (!nodule.voxel_coords || !Array.isArray(nodule.voxel_coords) || nodule.voxel_coords.length < 3) {\n                    console.warn(`结节 #${index+1} 没有有效坐标`);\n                    return;\n                }\n                \n                // 提取坐标\n                const [x, y, z] = nodule.voxel_coords.map(Math.round);\n                \n                // 仅显示当前Z切片上或附近的结节（容许一定误差）\n                const zTolerance = 1; // 允许的Z轴误差范围\n                if (Math.abs(z - currentZIndex) > zTolerance) {\n                    return;\n                }\n                \n                // 计算标记在图像上的位置（相对于图像尺寸）\n                // Normalize x and y by the intrinsic size of the slice image, not by\n                // maxSliceIndex (the number of Z slices). This assumes the server renders\n                // each Z slice at one image pixel per voxel in x and y.\n                const relativeX = x / sliceImg.naturalWidth;\n                const relativeY = y / sliceImg.naturalHeight;\n                \n                // 将相对坐标转换为绝对像素坐标\n                const markerX = relativeX * imgWidth;\n                const markerY = relativeY * imgHeight;\n                \n                // 创建标记元素\n                const marker = document.createElement('div');\n                marker.className = 'nodule-marker';\n                marker.dataset.id = nodule.id || index;\n                \n                // 设置标记大小（根据结节直径调整）\n                const size = Math.max(16, (nodule.diameter_mm || 5) * 2);\n                marker.style.width = `${size}px`;\n                marker.style.height = `${size}px`;\n                \n                // 设置标记位置\n                
marker.style.left = `${markerX}px`;\n                marker.style.top = `${markerY}px`;\n                \n                // 根据结节概率设置样式\n                const probability = (nodule.probability || 0.5) * 100;\n                if (probability > 70) {\n                    marker.classList.add('high-prob');\n                } else if (probability > 30) {\n                    marker.classList.add('medium-prob');\n                } else {\n                    marker.classList.add('low-prob');\n                }\n                // 添加标记编号\n                marker.innerHTML = `<span class=\"marker-label\">${index + 1}</span>`;\n                // 添加点击事件\n                marker.addEventListener('click', function() {\n                    // 高亮显示对应的结节列表项\n                    highlightNoduleInList(this.dataset.id);\n                    // 加载结节详情\n                    loadNoduleDetails(this.dataset.id);\n                });\n                \n                // 添加到容器\n                imageContainer.appendChild(marker);\n                \n            } catch (error) {\n                console.error(`为结节 #${index+1} 创建标记时出错:`, error);\n            }\n        });\n    };\n    \n    // 为当前切片渲染标记\n    window.renderNoduleMarkers(currentSliceIndex);\n}\n\n// 高亮显示结节列表中的特定结节\nfunction highlightNoduleInList(noduleId) {\n    // 移除所有当前选中项\n    document.querySelectorAll('.nodule-item.selected').forEach(item => {\n        item.classList.remove('selected');\n    });\n    \n    // 找到并选中指定结节\n    const noduleItem = document.querySelector(`.nodule-item[data-id=\"${noduleId}\"]`);\n    if (noduleItem) {\n        noduleItem.classList.add('selected');\n        \n        // 确保选中项在视图中可见\n        noduleItem.scrollIntoView({ behavior: 'smooth', block: 'nearest' });\n    }\n}\n\n// 加载结节详情\nfunction loadNoduleDetails(noduleId) {\n    console.log(`加载结节详情，ID: ${noduleId}`);\n    \n    if (!window.nodules) {\n        console.error('没有结节数据可用');\n        return;\n    }\n    \n    // 找到指定结节\n    const 
noduleIndex = parseInt(noduleId, 10);\n    // 通过ID查找结节数据\n    let nodule = null;\n    // 将noduleId转换为整数\n    const targetId = parseInt(noduleId, 10);\n    \n    // 遍历结节数组查找匹配的ID\n    for (let i = 0; i < window.nodules.length; i++) {\n        // 确保结节的id属性转换为整数进行比较\n        if (parseInt(window.nodules[i].id, 10) === targetId) {\n            nodule = window.nodules[i];\n            console.log(`找到匹配的结节，ID: ${targetId}`);\n            break;\n        }\n    }\n    \n    // 如果找不到匹配的结节，则使用索引方式获取\n    if (!nodule) {\n        console.warn(`未找到ID为${targetId}的结节，尝试使用索引方式获取`);\n        nodule = window.nodules[noduleIndex];\n    }\n\n    if (!nodule) {\n        console.error(`找不到ID为 ${noduleId} 的结节`);\n        return;\n    }\n    \n    // 1-based position in the list, matching the numbers shown in the nodule list and\n    // on the markers (nodule ids are not necessarily consecutive or 0-based)\n    const displayNumber = window.nodules.indexOf(nodule) + 1;\n    \n    // 获取详情容器\n    const detailsContainer = document.getElementById('nodule-detail-area');\n    if (!detailsContainer) {\n        console.error('找不到结节详情容器');\n        return;\n    }\n    \n    // 清空当前详情\n    detailsContainer.innerHTML = '';\n    \n    try {\n        // 提取结节数据\n        const diameter = (nodule.diameter_mm || 10).toFixed(1);\n        const probability = ((nodule.probability || 0.5) * 100).toFixed(0);\n        const coords = nodule.voxel_coords || [0, 0, 0];\n        const [x, y, z] = coords.map(v => Math.round(v));\n        \n        // 如果结节有Z坐标，跳转到对应切片\n        if (typeof z === 'number' && z >= 0 && z <= maxSliceIndex) {\n            currentSliceIndex = z;\n            loadSlice(z);\n        }\n        \n        // 创建详情内容\n        detailsContainer.innerHTML = `\n            <h3>结节 #${displayNumber} 详情</h3>\n            <div class=\"detail-row\">\n                <span class=\"detail-label\">直径:</span>\n                <span class=\"detail-value\">${diameter} mm</span>\n            </div>\n            <div class=\"detail-row\">\n                <span class=\"detail-label\">恶性概率:</span>\n                <span class=\"detail-value probability-value ${probability > 
70 ? 'high' : (probability > 30 ? 'medium' : 'low')}\">\n                    ${probability}%\n                </span>\n            </div>\n            <div class=\"detail-row\">\n                <span class=\"detail-label\">位置坐标:</span>\n                <span class=\"detail-value\">[${x}, ${y}, ${z}]</span>\n            </div>\n            <div class=\"detail-row\">\n                <span class=\"detail-label\">类型:</span>\n                <span class=\"detail-value\">${nodule.type || '未分类'}</span>\n            </div>\n            \n            <div class=\"actions-row\">\n                <button id=\"goto-nodule-btn\" class=\"btn btn-primary\">定位到结节</button>\n                <button id=\"nodule-report-btn\" class=\"btn btn-secondary\">查看报告</button>\n            </div>\n        `;\n        \n        // 添加按钮事件监听器\n        document.getElementById('goto-nodule-btn').addEventListener('click', function() {\n            // 定位到结节所在切片\n            if (typeof z === 'number' && z >= 0 && z <= maxSliceIndex) {\n                currentSliceIndex = z;\n                loadSlice(z);\n                \n                // 高亮显示结节标记\n                setTimeout(() => {\n                    const marker = document.querySelector(`.nodule-marker[data-id=\"${noduleId}\"]`);\n                    if (marker) {\n                        marker.classList.add('highlight');\n                        // 移除高亮状态\n                        setTimeout(() => {\n                            marker.classList.remove('highlight');\n                        }, 2000);\n                    }\n                }, 100);\n            }\n        });\n        \n        document.getElementById('nodule-report-btn').addEventListener('click', function() {\n            showMessage('结节报告功能尚未实现', 'info');\n        });\n        \n    } catch (error) {\n        console.error('加载结节详情出错:', error);\n        detailsContainer.innerHTML = `<div class=\"error-message\">加载结节详情失败: ${error.message}</div>`;\n    }\n}\n\n// 重置视图\nfunction 
resetView() {\n    console.log('重置视图');\n\n    // 如果没有数据，不执行任何操作\n    if (!lungSegmentationLoaded || maxSliceIndex <= 0) {\n        console.log('没有数据可重置');\n        return;\n    }\n\n    // 将当前切片索引重置为中间位置\n    currentSliceIndex = Math.floor(maxSliceIndex / 2);\n\n    // 加载中间切片\n    loadSlice(currentSliceIndex);\n\n    // 显示消息\n    showMessage('视图已重置', 'info');\n}\n\n// 显示消息\nfunction showMessage(message, type = 'info') {\n    console.log(`显示消息: ${message} (${type})`);\n\n    // 创建消息容器（如果不存在）\n    let msgContainer = document.getElementById('message-container');\n    if (!msgContainer) {\n        msgContainer = document.createElement('div');\n        msgContainer.id = 'message-container';\n        document.body.appendChild(msgContainer);\n    }\n\n    // 创建消息元素\n    const msgElement = document.createElement('div');\n    msgElement.className = `message ${type}`;\n    msgElement.innerHTML = `\n        <span class=\"message-icon\">\n            <i class=\"fas ${type === 'error' ? 'fa-exclamation-circle' : (type === 'success' ? 
'fa-check-circle' : 'fa-info-circle')}\"></i>\n        </span>\n        <span class=\"message-text\">${message}</span>\n    `;\n\n    // 添加到容器\n    msgContainer.appendChild(msgElement);\n\n    // 显示消息\n    setTimeout(() => {\n        msgElement.classList.add('show');\n    }, 10);\n\n    // 设置自动消失\n    setTimeout(() => {\n        msgElement.classList.remove('show');\n        setTimeout(() => {\n            msgElement.remove();\n        }, 300);\n    }, 3000);\n}\n\n// 更新UI状态\nfunction updateUIState(state) {\n    // 移除所有状态类\n    document.body.classList.remove(\n        'state-initial',\n        'state-uploading',\n        'state-detecting',\n        'state-detected'\n    );\n\n    // 添加新状态类\n    document.body.classList.add(`state-${state}`);\n\n    // 更新UI元素可见性\n    // Reset all step indicators first, so returning to an earlier state (e.g. a new\n    // upload after a finished detection) does not keep stale 'active'/'completed' classes\n    const uploadStep = document.getElementById('step-upload');\n    const detectStep = document.getElementById('step-detect');\n    const visualizeStep = document.getElementById('step-visualize');\n    [uploadStep, detectStep, visualizeStep].forEach(step => step.classList.remove('active', 'completed'));\n\n    switch (state) {\n        case 'initial':\n        case 'uploading':\n            uploadStep.classList.add('active');\n            break;\n        case 'detecting':\n            uploadStep.classList.add('completed');\n            detectStep.classList.add('active');\n            break;\n        case 'detected':\n            uploadStep.classList.add('completed');\n            detectStep.classList.add('completed');\n            visualizeStep.classList.add('active');\n            break;\n    }\n}"
  },
  {
    "path": "deploy/run.py",
    "content": "import os\nimport sys\nimport subprocess\nimport platform\nimport webbrowser\nimport time\n\ndef get_python_command():\n    \"\"\"获取Python命令\"\"\"\n    # Prefer the interpreter running this script, so the backend starts in the same\n    # environment whose packages check_dependencies() has just verified.\n    if sys.executable:\n        return sys.executable\n    if platform.system() == \"Windows\":\n        return \"python\"\n    else:\n        return \"python3\"\n\ndef check_dependencies():\n    \"\"\"检查必要的依赖是否已安装\"\"\"\n    try:\n        import flask\n        import flask_cors\n        import numpy\n        import tensorflow\n        import SimpleITK\n        print(\"✓ 所有必要的依赖已安装\")\n        return True\n    except ImportError as e:\n        print(f\"✗ 缺少依赖: {e}\")\n        print(\"请安装所需依赖: pip install flask flask-cors numpy tensorflow SimpleITK\")\n        return False\n\ndef create_directories():\n    \"\"\"创建必要的目录\"\"\"\n    os.makedirs(\"backend/uploads\", exist_ok=True)\n    os.makedirs(\"backend/models\", exist_ok=True)\n    print(\"✓ 目录创建完成\")\n\ndef run_backend_server():\n    \"\"\"运行后端服务器\"\"\"\n    python_cmd = get_python_command()\n    \n    # 构建命令\n    cmd = [python_cmd, \"backend/app.py\"]\n    \n    # 启动后端服务器\n    print(\"\\n启动后端服务器...\")\n    process = subprocess.Popen(cmd)\n    \n    # 等待服务器启动\n    time.sleep(2)\n    \n    # 打开浏览器\n    print(\"正在打开浏览器...\")\n    webbrowser.open(\"http://localhost:5000\")\n    \n    print(\"\\n服务器已启动!\\n\")\n    print(\"在浏览器中访问: http://localhost:5000\")\n    print(\"按 Ctrl+C 停止服务器\")\n    \n    try:\n        process.wait()\n    except KeyboardInterrupt:\n        print(\"\\n正在停止服务器...\")\n        process.terminate()\n        process.wait()\n        print(\"服务器已停止\")\n\ndef main():\n    \"\"\"主函数\"\"\"\n    # 切换到脚本所在目录\n    script_dir = os.path.dirname(os.path.abspath(__file__))\n    os.chdir(script_dir)\n    \n    print(\"CT图像分析系统启动工具\")\n    print(\"=\" * 30)\n    \n    # 检查依赖\n    if not check_dependencies():\n        return\n    \n    # 创建必要的目录\n    create_directories()\n    \n    # 运行后端服务器\n    run_backend_server()\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "inference/__init__.py",
    "content": ""
  },
  {
    "path": "inference/c3d_classify_result-1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139.csv",
    "content": "patient_id,nodule_id,voxel_x,voxel_y,voxel_z,world_x,world_y,world_z,diameter_mm,prob\n1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139,1,253,256,81,58.0,76.0,-269.5,16.0,0.9999972581863403\n1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139,9,192,180,95,-3.0,0.0,-255.5,16.0,0.9858711361885071\n"
  },
  {
    "path": "inference/classifier.py",
    "content": "import numpy\nimport pandas\nimport os\nfrom detector import extract_dicom_images_patient,get_papaya_coords,prepare_image_for_net3D,predict_nodule_type\nfrom keras.models import load_model, model_from_json\nfrom keras.optimizers import SGD\n\nfrom util.image_util import load_patient_images, rescale_patient_images\nfrom util.ml.metrics import get_3d_pixel_l2_distance\nfrom util.progress_watch import Stopwatch\n\nPREDICT_STEP = 12\nCUBE_SIZE = 32\nP_TH = 0.85\n\ndef scan(dicom_path, only_patient_id, workspace):\n    target_dir = workspace\n    boxes = []\n    centers = []\n    CONTINUE_JOB = True\n    sw = Stopwatch.start_new()\n    pixel_spacing, dicom_size, png_size, invert_order = extract_dicom_images_patient(dicom_path, target_dir)\n    print(\"png_size: \", png_size)\n    final_nodules_df = multipule_test(workspace, only_patient_id, CONTINUE_JOB)\n    final_nodules_df_sort = final_nodules_df.sort_values(['nodule_chance'], ascending=False)\n    print(\"predict from malignancy...\")\n    print(\"*\"*20)\n    print(final_nodules_df_sort)\n    print(\"*\" * 20)\n    if len(final_nodules_df_sort) == 0:\n        return boxes, centers\n\n    ggn_npy = predict_nodule_type(final_nodules_df_sort, png_size, workspace)\n    json_file = open('./models/c3d_malignancy_regreession.json', 'r')\n    loaded_model_json = json_file.read()\n    json_file.close()\n    loaded_model = model_from_json(loaded_model_json)\n\n    model_weight = './models/c3d_malignancy_regreession_04_0.8719.hd5'\n    loaded_model.load_weights(model_weight)\n    print(\"Loaded model_0 from disk\")\n    # 'learning_rate' replaces the deprecated 'lr' argument. 'decay' is dropped because it only\n    # affects training, and the newer Keras optimizers (TF 2.11+ / Keras 3) no longer accept it.\n    sgd = SGD(learning_rate=0.0001, momentum=0.9, nesterov=True)\n    loaded_model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])\n    model_result = loaded_model.predict(ggn_npy, batch_size=20, verbose=1)\n    ggn_class_list = []\n    for ii in model_result:\n        print(ii)\n        ii_index = numpy.argmax(ii)\n        print(\"result from 
malignancy..\")\n        print(generate_ggn_class(ii_index), round(ii[ii_index], 3))\n        # Append one entry per prediction: the loop below reads ggn_class_list[i] in step\n        # with the rows of final_nodules_df_sort, so skipping one would misalign the two\n        # lists or raise IndexError. (For a two-class softmax output the argmax\n        # probability is never below 0.5 anyway.)\n        ggn_class_list.append([generate_ggn_class(ii_index), round(ii[ii_index], 3)])\n\n    i = 0\n    for index, row in final_nodules_df_sort.iterrows():\n        print(ggn_class_list[i])\n        coord_z = row[\"coord_z\"]\n        coord_y = row[\"coord_y\"]\n        coord_x = row[\"coord_x\"]\n        print(\"index-x-y-z-p\", coord_x, coord_y, coord_z, row[\"nodule_chance\"])\n        box, center = get_papaya_coords(coord_x, coord_y, coord_z, row[\"nodule_chance\"], pixel_spacing, dicom_size,\n                                        png_size, invert_order, ggn_class_list[i])\n\n        boxes.append(box)\n        centers.append(center)\n        # draw_overlay(target_dir, coord_x, coord_y, coord_z, str(i))\n        # draw_overlay_dicom(pixels, only_patient_id, coord_x, coord_y, coord_z, str(i), pixel_spacing, dicom_size,\n        #                    png_size, invert_order, target_dir)\n        i += 1\n\n    print(\"ALL Complete in : \", sw.get_elapsed_seconds(), \" seconds\")\n    return boxes, centers\n\ndef generate_ggn_class(ii_index):\n    if ii_index == 0:\n        return 'non_malignancy'\n    if ii_index == 1:\n        return 'malignancy'\n\n\ndef multipule_test(workspace, only_patient_id, CONTINUE_JOB):\n    temp_df = []\n    for model_version in [\"model_loc.hd5\", \"model_loc_val_0.96.hd5\"]:\n        print(\"gpu begin:\")\n        pred_nodules_df = locate_malignancy(workspace, \"models/\" + model_version, CONTINUE_JOB, only_patient_id=only_patient_id,\n                                        magnification=1, flip=False, train_data=True, holdout_no=None,\n                                        ext_name=\"luna16_fs\")\n        pred_nodules_df = pred_nodules_df[pred_nodules_df[\"nodule_chance\"] > P_TH]\n        temp_df.append(pred_nodules_df)\n    temp_dataframe = 
pandas.concat(temp_df)\n    df = reduce_predicts_same_slice(temp_dataframe)\n    # df = temp_df\n    return df\n\ndef locate_malignancy(png_path, model_weight,CONTINUE_JOB, only_patient_id,\n                                        magnification=1, flip=False, train_data=True, holdout_no=None,\n                                        ext_name=\"luna16_fs\"):\n    patient_id = only_patient_id\n    all_predictions_csv = []\n    sw = Stopwatch.start_new()\n    json_file = open('../service/workdir/model_loc.json', 'r')\n    loaded_model_json = json_file.read()\n    json_file.close()\n    model = model_from_json(loaded_model_json)\n    # load weights into new model\n    model.load_weights(model_weight)\n    patient_img = load_patient_images(png_path + \"/png/\", \"*_i.png\", [])\n    if magnification != 1:\n        patient_img = rescale_patient_images(patient_img, (1, 1, 1), magnification)\n\n    patient_mask = load_patient_images(png_path + \"/png/\", \"*_m.png\", [])\n    if magnification != 1:\n        patient_mask = rescale_patient_images(patient_mask, (1, 1, 1), magnification, is_mask_image=True)\n\n    step = PREDICT_STEP\n    CROP_SIZE = CUBE_SIZE\n\n    predict_volume_shape_list = [0, 0, 0]\n    for dim in range(3):\n        dim_indent = 0\n        while dim_indent + CROP_SIZE < patient_img.shape[dim]:\n            predict_volume_shape_list[dim] += 1\n            dim_indent += step\n\n    predict_volume_shape = (predict_volume_shape_list[0], predict_volume_shape_list[1], predict_volume_shape_list[2])\n    predict_volume = numpy.zeros(shape=predict_volume_shape, dtype=float)\n    print(\"Predict volume shape: \", predict_volume.shape)\n    done_count = 0\n    skipped_count = 0\n    batch_size = 32\n    batch_list = []\n    batch_list_coords = []\n    patient_predictions_csv = []\n    cube_img = None\n    annotation_index = 0\n    for z in range(0, 
predict_volume_shape[0]):\n        for y in range(0, predict_volume_shape[1]):\n            for x in range(0, predict_volume_shape[2]):\n                cube_img = patient_img[z * step:z * step + CROP_SIZE, y * step:y * step + CROP_SIZE,\n                           x * step:x * step + CROP_SIZE]\n                cube_mask = patient_mask[z * step:z * step + CROP_SIZE, y * step:y * step + CROP_SIZE,\n                            x * step:x * step + CROP_SIZE]\n\n                if cube_mask.sum() < 2000:\n                    skipped_count += 1\n                else:\n                    if flip:\n                        cube_img = cube_img[:, :, ::-1]\n\n                    img_prep = prepare_image_for_net3D(cube_img)\n                    batch_list.append(img_prep)\n                    batch_list_coords.append((z, y, x))\n\n                # Predict when the batch is full, and flush the final partial batch at the\n                # last cube; otherwise up to batch_size - 1 cubes at the end of the scan\n                # would never reach the model.\n                is_last_cube = (z == predict_volume_shape[0] - 1 and y == predict_volume_shape[1] - 1\n                                and x == predict_volume_shape[2] - 1)\n                if len(batch_list) == batch_size or (is_last_cube and batch_list):\n                    batch_data = numpy.vstack(batch_list)\n                    p = model.predict(batch_data, batch_size=batch_size)\n                    for i in range(len(p[0])):\n                        p_z = batch_list_coords[i][0]\n                        p_y = batch_list_coords[i][1]\n                        p_x = batch_list_coords[i][2]\n                        malignancy_chance = p[0][i][0]\n                        predict_volume[p_z, p_y, p_x] = malignancy_chance\n                        if malignancy_chance > P_TH:\n                            p_z = p_z * step + CROP_SIZE / 2\n                            p_y = p_y * step + CROP_SIZE / 2\n                            p_x = p_x * step + CROP_SIZE / 2\n\n                            p_z_perc = round(p_z / patient_img.shape[0], 4)\n                            p_y_perc = round(p_y / patient_img.shape[1], 4)\n                            p_x_perc = round(p_x / patient_img.shape[2], 4)\n\n                            nodule_chance = round(malignancy_chance, 4)\n                            patient_predictions_csv_line = [annotation_index, p_x_perc, p_y_perc, p_z_perc,\n                                                            nodule_chance]\n                            patient_predictions_csv.append(patient_predictions_csv_line)\n                            all_predictions_csv.append([patient_id] + patient_predictions_csv_line)\n                            annotation_index += 1\n\n                    batch_list = []\n                    batch_list_coords = []\n                done_count += 1\n                if done_count % 10000 == 0:\n                    print(\"Scan: \", done_count, \" skipped:\", skipped_count)\n\n    df = pandas.DataFrame(patient_predictions_csv,\n                          columns=[\"anno_index\", \"coord_x\", \"coord_y\", \"coord_z\",  \"nodule_chance\"])\n\n    filter_patient_nodules_predictions(df, patient_id, CROP_SIZE * magnification, png_path)\n\n    print(predict_volume.mean())\n    print(\"GPU costs : \", sw.get_elapsed_seconds(), \" seconds\")\n    return df\n\ndef filter_patient_nodules_predictions(df_nodule_predictions: pandas.DataFrame, patient_id, view_size, png_path):\n    patient_mask = load_patient_images(png_path+\"/png/\", \"*_m.png\")\n    delete_indices = []\n    for index, row in df_nodule_predictions.iterrows():\n        z_perc = row[\"coord_z\"]\n        y_perc = row[\"coord_y\"]\n        center_x = int(round(row[\"coord_x\"] * patient_mask.shape[2]))\n        center_y = int(round(y_perc * patient_mask.shape[1]))\n        center_z = int(round(z_perc * patient_mask.shape[0]))\n\n        start_y = center_y - view_size / 2\n        start_x = center_x - view_size / 2\n        nodule_in_mask = False\n        for z_index in [-1, 0, 1]:\n            img = patient_mask[z_index + center_z]\n            start_x = int(start_x)\n    
        start_y = int(start_y)\n            view_size = int(view_size)\n            img_roi = img[start_y:start_y + view_size, start_x:start_x + view_size]\n            if img_roi.sum() > 255:  # more than 1 pixel of mask.\n                nodule_in_mask = True\n\n        if not nodule_in_mask:\n            print(\"Nodule not in mask: \", (center_x, center_y, center_z))\n            #delete_indices.append(df_nodule_predictions.loc[index,])\n            delete_indices.append(index)\n        else:\n            if center_z < 30:\n                print(\"Z < 30: \", patient_id, \" center z:\", center_z, \" y_perc: \", y_perc)\n                #delete_indices.append(df_nodule_predictions.loc[index])\n                delete_indices.append(index)\n\n            if (z_perc > 0.75 or z_perc < 0.25) and y_perc > 0.85:\n                print(\"SUSPICIOUS FALSEPOSITIVE: \", patient_id, \" center z:\", center_z, \" y_perc: \", y_perc)\n                #delete_indices.append(df_nodule_predictions.loc[index])\n                delete_indices.append(index)\n\n            if center_z < 50 and y_perc < 0.30:\n                print(\"SUSPICIOUS FALSEPOSITIVE OUT OF RANGE: \", patient_id, \" center z:\", center_z, \" y_perc: \",\n                      y_perc)\n                #delete_indices.append(df_nodule_predictions.loc[index])\n                delete_indices.append(index)\n    print(\"slice to drop:\\t\",delete_indices)\n    df_nodule_predictions.drop(df_nodule_predictions.index[delete_indices], inplace=True)\n    return df_nodule_predictions\n\ndef reduce_predicts_same_slice(pred_nodules_df):\n    rows_filter = []\n    pred_nodules_df_local = pred_nodules_df.sort_values([\"coord_z\"], ascending=False)\n    if len(pred_nodules_df_local) <= 1:\n        return pred_nodules_df_local\n    compare_row = pred_nodules_df_local.iloc[0]\n    for row_index, row in pred_nodules_df_local[1:].iterrows():\n        if compare_row[\"coord_z\"] == row[\"coord_z\"]:\n            dist = 
get_3d_pixel_l2_distance(compare_row, row)\n            if dist > 0.2:\n                rows_filter.append(row)\n        else:\n            rows_filter.append(compare_row)\n            compare_row = row\n    if len(rows_filter) == 0:\n        rows_filter.append(compare_row)\n    last_row = rows_filter[len(rows_filter)-1]\n    if last_row[\"coord_z\"] != compare_row[\"coord_z\"]:\n        rows_filter.append(compare_row)\n    columns = [\"anno_index\", \"coord_x\", \"coord_y\", \"coord_z\", \"nodule_chance\"]\n    res_df = pandas.DataFrame(rows_filter, columns=columns)\n    return res_df"
  },
  {
    "path": "inference/detector.py",
    "content": "from service import settings_jjyang\nimport cv2\nimport pandas\nimport os\nimport glob\nimport numpy\nfrom keras import backend as K\nfrom keras.models import model_from_json\n\nfrom util.cube import save_cube_img, get_cube_from_img\nfrom util.image.processing import normalize_hu_values\nfrom util.dicom_util import get_pixels_hu, extract_dicom_images_patient, load_dicom_slices\nfrom util.image_util import prepare_image_for_net3D, load_patient_images, rescale_patient_images\nfrom util.ml.metrics import get_3d_pixel_l2_distance\nfrom util.progress_watch import Stopwatch\n\nK.set_image_data_format(\"channels_last\")  # 更新为TF2方式设置\nimport tensorflow as tf\n\n# 在TF2中设置GPU内存使用方式\ngpus = tf.config.experimental.list_physical_devices('GPU')\nif gpus:\n    try:\n        # 将GPU内存使用限制为可用内存的30%\n        for gpu in gpus:\n            tf.config.experimental.set_memory_growth(gpu, True)\n            tf.config.experimental.set_virtual_device_configuration(\n                gpu,\n                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024 * 3)]  # 约30%的10GB显存\n            )\n    except RuntimeError as e:\n        print(e)\n\nCUBE_SIZE = 32\nMEAN_PIXEL_VALUE = settings_jjyang.MEAN_PIXEL_VALUE_NODULE\nNEGS_PER_POS = 20\nP_TH = 0.7\n\nPREDICT_STEP = 12\nUSE_DROPOUT = False\n\nBOX_size = 20\nBOX_depth = 9\n# NODULE_CHANCE = 0.5\nNODULE_DIAMM = 1.0\nCUBE_IMGTYPE_SRC = \"_i\"\n\ndef filter_patient_nodules_predictions(df_nodule_predictions: pandas.DataFrame, patient_id, view_size, png_path):\n    patient_mask = load_patient_images(png_path+\"/png/\", \"*_m.png\")\n    delete_indices = []\n    for index, row in df_nodule_predictions.iterrows():\n        z_perc = row[\"coord_z\"]\n        y_perc = row[\"coord_y\"]\n        center_x = int(round(row[\"coord_x\"] * patient_mask.shape[2]))\n        center_y = int(round(y_perc * patient_mask.shape[1]))\n        center_z = int(round(z_perc * patient_mask.shape[0]))\n\n        mal_score = 
row[\"diameter_mm\"]\n        start_y = center_y - view_size / 2\n        start_x = center_x - view_size / 2\n        nodule_in_mask = False\n        for z_index in [-1, 0, 1]:\n            img = patient_mask[z_index + center_z]\n            start_x = int(start_x)\n            start_y = int(start_y)\n            view_size = int(view_size)\n            img_roi = img[start_y:start_y + view_size, start_x:start_x + view_size]\n            if img_roi.sum() > 255:  # more than 1 pixel of mask.\n                nodule_in_mask = True\n\n        if not nodule_in_mask:\n            print(\"Nodule not in mask: \", (center_x, center_y, center_z))\n            if mal_score > 0:\n                mal_score *= -1\n            df_nodule_predictions.loc[index, \"diameter_mm\"] = mal_score\n        else:\n            if center_z < 30:\n                print(\"Z < 30: \", patient_id, \" center z:\", center_z, \" y_perc: \", y_perc)\n                if mal_score > 0:\n                    mal_score *= -1\n                df_nodule_predictions.loc[index, \"diameter_mm\"] = mal_score\n\n            if (z_perc > 0.75 or z_perc < 0.25) and y_perc > 0.85:\n                print(\"SUSPICIOUS FALSEPOSITIVE: \", patient_id, \" center z:\", center_z, \" y_perc: \", y_perc)\n\n            if center_z < 50 and y_perc < 0.30:\n                print(\"SUSPICIOUS FALSEPOSITIVE OUT OF RANGE: \", patient_id, \" center z:\", center_z, \" y_perc: \",\n                      y_perc)\n\n    df_nodule_predictions.drop(df_nodule_predictions.index[delete_indices], inplace=True)\n    return df_nodule_predictions\n\n\ndef predict_cubes(png_path, model_path, only_patient_id=None,  magnification=1, flip=False):\n    patient_id = only_patient_id\n    all_predictions_csv = []\n    sw = Stopwatch.start_new()\n    json_file = open('../service/workdir/model_loc.json', 'r')\n    loaded_model_json = json_file.read()\n    json_file.close()\n    model = model_from_json(loaded_model_json)\n    # load weights into 
new model\n    model.load_weights(model_path)\n    patient_img = load_patient_images(png_path+\"/png/\", \"*_i.png\", [])\n    if magnification != 1:\n        patient_img = rescale_patient_images(patient_img, (1, 1, 1), magnification)\n\n    patient_mask = load_patient_images(png_path+\"/png/\", \"*_m.png\", [])\n    if magnification != 1:\n        patient_mask = rescale_patient_images(patient_mask, (1, 1, 1), magnification, is_mask_image=True)\n\n    step = PREDICT_STEP\n    CROP_SIZE = CUBE_SIZE\n\n    predict_volume_shape_list = [0, 0, 0]\n    for dim in range(3):\n        dim_indent = 0\n        while dim_indent + CROP_SIZE < patient_img.shape[dim]:\n            predict_volume_shape_list[dim] += 1\n            dim_indent += step\n\n    predict_volume_shape = (predict_volume_shape_list[0], predict_volume_shape_list[1], predict_volume_shape_list[2])\n    predict_volume = numpy.zeros(shape=predict_volume_shape, dtype=float)\n    print(\"Predict volume shape: \", predict_volume.shape)\n    done_count = 0\n    skipped_count = 0\n    batch_size = 32\n    batch_list = []\n    batch_list_coords = []\n    patient_predictions_csv = []\n    cube_img = None\n    annotation_index = 0\n    for z in range(0, predict_volume_shape[0]):\n        for y in range(0, predict_volume_shape[1]):\n            for x in range(0, predict_volume_shape[2]):\n                # if cube_img is None:\n                cube_img = patient_img[z * step:z * step + CROP_SIZE, y * step:y * step + CROP_SIZE,\n                           x * step:x * step + CROP_SIZE]\n                cube_mask = patient_mask[z * step:z * step + CROP_SIZE, y * step:y * step + CROP_SIZE,\n                            x * step:x * step + CROP_SIZE]\n\n                if cube_mask.sum() < 2000:\n                    skipped_count += 1\n                else:\n                    if flip:\n                        cube_img = cube_img[:, :, ::-1]\n\n                    img_prep = prepare_image_for_net3D(cube_img)\n                 
   batch_list.append(img_prep)\n                    batch_list_coords.append((z, y, x))\n                    if len(batch_list) % batch_size == 0:\n                        batch_data = numpy.vstack(batch_list)\n                        p = model.predict(batch_data, batch_size=batch_size)\n                        for i in range(len(p[0])):\n                            p_z = batch_list_coords[i][0]\n                            p_y = batch_list_coords[i][1]\n                            p_x = batch_list_coords[i][2]\n                            nodule_chance = p[0][i][0]\n                            predict_volume[p_z, p_y, p_x] = nodule_chance\n                            if nodule_chance > P_TH:\n                                p_z = p_z * step + CROP_SIZE / 2\n                                p_y = p_y * step + CROP_SIZE / 2\n                                p_x = p_x * step + CROP_SIZE / 2\n\n                                p_z_perc = round(p_z / patient_img.shape[0], 4)\n                                p_y_perc = round(p_y / patient_img.shape[1], 4)\n                                p_x_perc = round(p_x / patient_img.shape[2], 4)\n                                diameter_mm = round(p[1][i][0], 4)\n                                diameter_perc = round(diameter_mm / patient_img.shape[2], 4)\n                                nodule_chance = round(nodule_chance, 4)\n                                patient_predictions_csv_line = [annotation_index, p_x_perc, p_y_perc, p_z_perc,\n                                                                diameter_perc, nodule_chance, diameter_mm]\n                                patient_predictions_csv.append(patient_predictions_csv_line)\n                                all_predictions_csv.append([patient_id] + patient_predictions_csv_line)\n                                annotation_index += 1\n\n                        batch_list = []\n                        batch_list_coords = []\n                done_count += 1\n                if 
done_count % 10000 == 0:\n                    print(\"Scan: \", done_count, \" skipped:\", skipped_count)\n\n    df = pandas.DataFrame(patient_predictions_csv,\n                          columns=[\"anno_index\", \"coord_x\", \"coord_y\", \"coord_z\", \"diameter\", \"nodule_chance\",\n                                   \"diameter_mm\"])\n    filter_patient_nodules_predictions(df, patient_id, CROP_SIZE * magnification, png_path)\n    # df.to_csv(settings_jjyang.BASE_DIR_SSD+\"temp_dir/\" + patient_id + \".csv\", index=False)\n    print(predict_volume.mean())\n    print(\"GPU costs : \", sw.get_elapsed_seconds(), \" seconds\")\n    return df\n\ndef reduce_predicts_same_slice(pred_nodules_df):\n    rows_filter = []\n    pred_nodules_df_local = pred_nodules_df.sort_values([\"coord_z\", \"diameter_mm\"], ascending=False)\n    if len(pred_nodules_df_local) <= 1:\n        return pred_nodules_df_local\n    i = 0\n    compare_row = pred_nodules_df_local.iloc[0]\n    for row_index, row in pred_nodules_df_local[1:].iterrows():\n        if compare_row[\"coord_z\"] == row[\"coord_z\"]:\n            dist = get_3d_pixel_l2_distance(compare_row, row)\n            if dist > 0.2:\n                rows_filter.append(row)\n        else:\n            rows_filter.append(compare_row)\n            compare_row = row\n        i += 1\n    if len(rows_filter) == 0:\n        rows_filter.append(compare_row)\n    last_row = rows_filter[len(rows_filter)-1]\n    if last_row[\"coord_z\"] != compare_row[\"coord_z\"]:\n        rows_filter.append(compare_row)\n    columns = [\"anno_index\", \"coord_x\", \"coord_y\", \"coord_z\", \"diameter\", \"nodule_chance\", \"diameter_mm\"]\n    res_df = pandas.DataFrame(rows_filter, columns=columns)\n    return res_df\n\n\ndef multipule_test(workspace, only_patient_id, CONTINUE_JOB):\n    temp_df = []\n    for model_version in [\"model_loc.hd5\", \"model_loc_val_0.96.hd5\"]:\n        print(\"gpu begin:\")\n        pred_nodules_df = predict_cubes(workspace, 
\"models/\" + model_version, only_patient_id=only_patient_id,\n                                        magnification=1, flip=False)\n        pred_nodules_df = pred_nodules_df[pred_nodules_df[\"nodule_chance\"] > P_TH]\n        pred_nodules_df = pred_nodules_df[pred_nodules_df[\"diameter_mm\"] > NODULE_DIAMM]\n\n        temp_df.append(pred_nodules_df)\n    temp_dataframe = pandas.concat(temp_df)\n    df = reduce_predicts_same_slice(temp_dataframe)\n    # df = temp_df\n    return df\n\n\ndef draw_overlay_dicom(pixels,  coord_x, coord_y, coord_z, i, pixel_spacing, dicom_size, png_size, invert_order, png_path):\n    z = int(coord_z * png_size[0])\n    y = int(coord_y * png_size[1])\n    x = int(coord_x * png_size[2])\n    dicom_z = int(z / pixel_spacing[2])\n    dicom_y = int(y / pixel_spacing[1])\n    dicom_x = int(x / pixel_spacing[0])\n    x1 = dicom_x - BOX_size\n    y1 = dicom_y - BOX_size\n    x2 = dicom_x + BOX_size\n    y2 = dicom_y + BOX_size\n    print(\"invert_order:\", invert_order)\n    if invert_order:\n        new_z = dicom_z\n        org_img = pixels[new_z]\n        print(\"dicom_coord_x_y_z: \", dicom_x, dicom_y, new_z)\n    else:\n        new_z = dicom_size[0] - dicom_z\n        org_img = pixels[new_z]\n        # print(\"dicom_coord_x_y_z: \", dicom_size[2] - dicom_x, dicom_y, new_z)  #for papaya reverse left-right\n        print(\"dicom_coord_x_y_z: \", dicom_x, dicom_y, new_z)\n\n    org_img = normalize_hu_values(org_img)\n    cv2.rectangle(org_img, (x1, y1), (x2, y2), (255, 0, 0), 1)\n    # suffix = i + \"_\" + str(dicom_size[0] - dicom_z)\n    suffix = i + \"_\" + str(new_z)\n    cv2.imwrite(png_path + \"/\" + \"overlay_dicom\" + suffix + \".png\", org_img * 255)\n\n\ndef get_papaya_coords(coord_x, coord_y, coord_z, nodule_chance, pixel_spacing, dicom_size, png_size, invert_order, ggn_class):\n    z = int(coord_z * png_size[0])\n    y = 
int(coord_y * png_size[1])\n    x = int(coord_x * png_size[2])\n    dicom_z = int(z / pixel_spacing[2])\n    dicom_y = int(y / pixel_spacing[1])\n    dicom_x = int(x / pixel_spacing[0])\n\n    # print(\"invert_order:\", invert_order)\n    if invert_order:\n        new_z = dicom_size[0] - dicom_z\n        new_x = dicom_size[2] - dicom_x\n        new_y = dicom_y\n        print(\"invert_order: \", invert_order, \"new_coord_z_y_x: \", new_z, new_y, new_x, \" dicom_coord_z_y_x: \", dicom_z, dicom_y, dicom_x)\n    else:\n        new_z = dicom_size[0] - dicom_z\n        new_x = dicom_size[2] - dicom_x\n        new_y = dicom_y\n        # print(\"dicom_coord_x_y_z: \", dicom_size[2] - dicom_x, dicom_y, new_z)  #for papaya reverse left-right\n        print(\"invert_order: \", invert_order, \"new_coord_z_y_x: \", new_z, new_y, new_x,\" dicom_coord_z_y_x: \", dicom_z, dicom_y, dicom_x)\n    x1 = new_x - BOX_size\n    y1 = new_y - BOX_size\n    x2 = new_x + BOX_size\n    y2 = new_y + BOX_size\n    z1 = new_z - BOX_depth\n    z2 = new_z + BOX_depth\n    box = [z1, y1, x1, z2, y2, x2]\n    center = [new_z, new_y, new_x, nodule_chance*100, ggn_class[0], ggn_class[1]*100]\n    return box, center\n\n\ndef get_papaya_coords_only(coord_x, coord_y, coord_z, nodule_chance, pixel_spacing, dicom_size, png_size, invert_order):\n    z = int(coord_z * png_size[0])\n    y = int(coord_y * png_size[1])\n    x = int(coord_x * png_size[2])\n    dicom_z = int(z / pixel_spacing[2])\n    dicom_y = int(y / pixel_spacing[1])\n    dicom_x = int(x / pixel_spacing[0])\n\n    # print(\"invert_order:\", invert_order)\n    if invert_order:\n        new_z = dicom_size[0] - dicom_z\n        new_x = dicom_size[2] - dicom_x\n        new_y = dicom_y\n        print(\"invert_order: \", invert_order, \"new_coord_z_y_x: \", new_z, new_y, new_x, \" dicom_coord_z_y_x: \", dicom_z, dicom_y, dicom_x)\n    else:\n        new_z = dicom_size[0] - dicom_z\n        new_x = dicom_size[2] - dicom_x\n        new_y = dicom_y\n   
     # print(\"dicom_coord_x_y_z: \", dicom_size[2] - dicom_x, dicom_y, new_z)  #for papaya reverse left-right\n        print(\"invert_order: \", invert_order, \"new_coord_z_y_x: \", new_z, new_y, new_x,\" dicom_coord_z_y_x: \", dicom_z, dicom_y, dicom_x)\n    x1 = new_x - BOX_size\n    y1 = new_y - BOX_size\n    x2 = new_x + BOX_size\n    y2 = new_y + BOX_size\n    z1 = new_z - BOX_depth\n    z2 = new_z + BOX_depth\n    box = [z1, y1, x1, z2, y2, x2]\n    center = [new_z, new_y, new_x, round(nodule_chance,3)]\n    return box, center\n\n\ndef run(dicom_path, png_save_dir):\n    CONTINUE_JOB = True\n    sw = Stopwatch.start_new()\n    pixel_spacing, dicom_size, png_size, invert_order = extract_dicom_images_patient(dicom_path, png_save_dir)\n    final_nodules_df = multipule_test(png_save_dir, None, CONTINUE_JOB)\n    # final_nodules_df = pandas.read_csv(settings_jjyang.OVERLAY_PATH + only_patient_id + \".csv\")\n    slices = load_dicom_slices(dicom_path)\n    pixels = get_pixels_hu(slices)\n    i = 0\n    for index, row in final_nodules_df.iterrows():\n        coord_z = row[\"coord_z\"]\n        coord_y = row[\"coord_y\"]\n        coord_x = row[\"coord_x\"]\n        print(\"index-x-y-z-p-size\", i, coord_x, coord_y, coord_z, row[\"nodule_chance\"], row[\"diameter_mm\"])\n        draw_overlay_dicom(pixels, coord_x, coord_y, coord_z, str(i), pixel_spacing, dicom_size, png_size, invert_order, png_save_dir)\n        # draw_overlay(only_patient_id, coord_x, coord_y, coord_z, str(i))\n        i += 1\n    # reduce_predicts_same_slice(pred_nodules_df)\n    print(\"ALL Complete in : \", sw.get_elapsed_seconds(), \" seconds\")\n\n\ndef predict_nodule_type(df, png_size, png_dir):\n    list = []\n    new_dir = png_dir + '/png/'\n    images = load_patient_images(new_dir, \"*\" + CUBE_IMGTYPE_SRC + \".png\")\n    i = 0\n    for index, row in df.iterrows():\n        coord_z = row[\"coord_z\"]\n        coord_y = row[\"coord_y\"]\n        coord_x = row[\"coord_x\"]\n        z = int(coord_z * 
png_size[0])\n        y = int(coord_y * png_size[1])\n        x = int(coord_x * png_size[2])\n        print(\"index-x-y-z-p-size-png\", x, y, z)\n        # print('z invert order : ', invert_order)\n        # if not invert_order:\n        #     coord_z = int((dicom_size[0] - row[\"z\"]) * pixel_spacing[2])\n        # else:\n        #     coord_z = int(row[\"z\"] * pixel_spacing[2])\n        cube_img = get_cube_from_img(images, x, y, z, 32)\n        # save_cube_img('./' + png_dir + '/_' + str(i) + '.png', cube_img, 4, 8)\n        save_cube_img(png_dir + '/_' + str(x)+'_'+str(y)+'_'+str(z) + '.png', cube_img, 4, 8)\n        img3d = prepare_image_for_net3D(cube_img)\n        list.append(img3d)\n        i += 1\n\n    img_numpy = numpy.vstack(list)\n    return img_numpy\n\ndef generate_ggn_class(ii_index):\n    if ii_index == 0:\n        return 'AAH'\n    if ii_index == 1:\n        return 'AIS'\n    if ii_index == 2:\n        return 'MIA'\n    if ii_index == 3:\n        return 'IA'\n    if ii_index == 4:\n        return 'OH'\n\n\ndef scan(dicom_path, only_patient_id, workspace, file_type=\"dicom\"):\n    \"\"\"\n    Scan for nodules; supports DICOM and MHD files\n    \n    :param dicom_path: DICOM directory or MHD file path\n    :param only_patient_id: patient ID\n    :param workspace: working directory\n    :param file_type: file type, 'dicom' or 'mhd'\n    :return: nodule bounding boxes and center points\n    \"\"\"\n    print(\"workspace\", workspace)\n    try:\n        target_dir = workspace\n        p_index = 0\n        # Extract the images\n        if file_type == \"dicom\":\n            # Extract images from the DICOM series\n            pixel_spacing, dicom_size, png_size, invert_order = extract_dicom_images_patient(dicom_path,  target_dir)\n        else:\n            # For MHD files, the images were already extracted in an earlier processing step\n            pixel_spacing = [1.0, 1.0, 1.0]  # default; can later be read from the MHD header if needed\n            dicom_size = [0, 0, 0]  # default\n            png_size = [0, 0, 0]    # default\n            invert_order = False\n            \n            # Try to get the actual size from the PNG directory\n            png_dir = os.path.join(target_dir, 'png')\n            if 
os.path.exists(png_dir):\n                png_files = glob.glob(png_dir + \"/*_i.png\")\n                if png_files:\n                    sample_img = cv2.imread(png_files[0], cv2.IMREAD_GRAYSCALE)\n                    png_size = [len(png_files), sample_img.shape[0], sample_img.shape[1]]\n        \n        # Predict with the specified model\n        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../service/workdir/jjnode_classify_resnet_best.h5')\n        df_pos = predict_cubes(target_dir, model_path, only_patient_id=only_patient_id, magnification=1, flip=False)\n        # Add the classification step\n        df_node = predict_nodule_type(df_pos, png_size, target_dir)\n        df_pos = df_node\n\n        # Collect bounding boxes and center points\n        boxes = []\n        record_list = []\n        records_png_paths = []\n        for index, row in df_pos.iterrows():\n            record = []\n            coord_x = row[\"coord_x\"]\n            coord_y = row[\"coord_y\"]\n            coord_z = row[\"coord_z\"]\n            nodule_chance = float(row[\"nodule_chance\"])\n            box_label = row[\"box_label\"]\n\n            # Draw the visualization\n            draw_overlay(target_dir, coord_x, coord_y, coord_z, str(p_index))\n            \n            try:\n                # Get the nodule coordinates\n                box = get_papaya_coords(\n                    coord_x, coord_y, coord_z, nodule_chance, pixel_spacing, dicom_size, png_size, invert_order, box_label)\n                boxes.append(box)\n\n                # Collect the nodule information\n                record.append(str(p_index))\n                record.append(coord_x)\n                record.append(coord_y)\n                record.append(coord_z)\n                record.append(box_label)\n                record.append(nodule_chance)\n                record_list.append(record)\n                \n                # Generate the cube image\n                png_f = make_test_cube(record)\n                records_png_paths.append(png_f)\n                \n            except Exception as e:\n                print(f\"Error while processing nodule: {e}\")\n                \n            p_index += 1\n  
          \n        return boxes, record_list\n        \n    except Exception as e:\n        print(f\"Error during scan: {e}\")\n        return [], []\n\n\ndef scan_only(dicom_path, only_patient_id, workspace, file_type=\"dicom\"):\n    \"\"\"\n    Only detect nodules, without classification; supports DICOM and MHD files\n    \n    :param dicom_path: DICOM directory or MHD file path\n    :param only_patient_id: patient ID\n    :param workspace: working directory\n    :param file_type: file type, 'dicom' or 'mhd'\n    :return: nodule bounding boxes and center points\n    \"\"\"\n    print(\"workspace\", workspace)\n    try:\n        target_dir = workspace\n        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../service/workdir/jjnode_loc_resnet_best.h5')\n        p_index = 0\n        \n        # Extract the images\n        if file_type == \"dicom\":\n            # Extract images from the DICOM series\n            pixel_spacing, dicom_size, png_size, invert_order = extract_dicom_images_patient(dicom_path, target_dir)\n        else:\n            # For MHD files, the images were already extracted in an earlier processing step\n            pixel_spacing = [1.0, 1.0, 1.0]  # default; can later be read from the MHD header if needed\n            dicom_size = [0, 0, 0]  # default\n            png_size = [0, 0, 0]    # default\n            invert_order = False\n            \n            # Try to get the actual size from the PNG directory\n            png_dir = os.path.join(target_dir, 'png')\n            if os.path.exists(png_dir):\n                png_files = glob.glob(png_dir + \"/*_i.png\")\n                if png_files:\n                    sample_img = cv2.imread(png_files[0], cv2.IMREAD_GRAYSCALE)\n                    png_size = [len(png_files), sample_img.shape[0], sample_img.shape[1]]\n        \n        # Predict nodules\n        df_pos = predict_cubes(target_dir, model_path, only_patient_id=only_patient_id, magnification=1, flip=False)\n        \n        # Filter the predictions\n        df_pos = filter_patient_nodules_predictions(df_pos, only_patient_id, BOX_size, target_dir)\n        df_pos = reduce_predicts_same_slice(df_pos)\n        \n        # Collect bounding boxes and center points\n        boxes = []\n        centers = []\n        for index, row in df_pos.iterrows():\n            coord_x = row[\"coord_x\"]\n            coord_y = 
row[\"coord_y\"]\n            coord_z = row[\"coord_z\"]\n            nodule_chance = float(row[\"nodule_chance\"])\n            \n            # Draw the visualization\n            draw_overlay(target_dir, coord_x, coord_y, coord_z, str(p_index))\n            \n            try:\n                # Get the nodule coordinates\n                box = get_papaya_coords_only(\n                    coord_x, coord_y, coord_z, nodule_chance, pixel_spacing, dicom_size, png_size, invert_order)\n                boxes.append(box)\n                \n                # Nodule center (coordinates are stored as fractions of the volume size)\n                z_perc = coord_z\n                y_perc = coord_y\n                x_perc = coord_x\n                nodule_chance_perc = nodule_chance\n                diamm = NODULE_DIAMM\n                center = [int(z_perc * dicom_size[0]), int(y_perc * dicom_size[1]), int(x_perc * dicom_size[2]), diamm, nodule_chance_perc * 100]\n                centers.append(center)\n                \n            except Exception as e:\n                print(f\"Error while processing nodule: {e}\")\n                \n            p_index += 1\n            \n        return boxes, centers\n        \n    except Exception as e:\n        print(f\"Error during scan: {e}\")\n        return [], []\n\n\n# horos_5358\nif __name__ == \"__main__\":\n    workspace = os.path.join(os.path.abspath(os.path.dirname(__file__)),\n                             './static/workspace/KHAXVYV5')\n    scan('E:/renji_hospital_dicom/AIS/ChenZhou/IE2UNXKC/KHAXVYV5',\n         'KHAXVYV5', workspace)\n    # df = pandas.read_csv(\"./test.csv\")\n    # predict_nodule_type(df,[354, 305, 305],'./static/workspace/0708c00f6117ed977bbe1b462b56848c')\n\n"
  },
  {
    "path": "inference/negative_sample_selection.py",
    "content": "import os\nimport sys\nimport numpy as np\nimport random\nimport glob\nfrom sklearn.cluster import KMeans\nimport matplotlib.pyplot as plt\nfrom scipy import ndimage\n# Import the CTData class\nfrom data.dataclass.CTData import CTData\n# Cube edge length in voxels\nCUBE_SIZE = 32\ndef variance_of_laplacian(image):\n    \"\"\"Compute the variance of the image Laplacian, a measure of sharpness/blur\"\"\"\n    # Compute the Laplacian of the image\n    laplacian = ndimage.laplace(image)\n    # Return the variance of the Laplacian\n    return np.var(laplacian)\n\ndef texture_score(cube):\n    \"\"\"Compute a texture score for a cube, used to check that it contains enough texture\"\"\"\n    # Work in float32 so the Laplacian, Sobel filters and squared gradients cannot overflow integer dtypes\n    cube = np.asarray(cube, dtype=np.float32)\n    # Method 1: variance of the Laplacian\n    lap_scores = []\n    for i in range(cube.shape[0]):\n        lap_scores.append(variance_of_laplacian(cube[i]))\n    lap_score = np.mean(lap_scores)\n    \n    # Method 2: gradient magnitude\n    grad_x = ndimage.sobel(cube, axis=2)\n    grad_y = ndimage.sobel(cube, axis=1)\n    grad_z = ndimage.sobel(cube, axis=0)\n    grad_mag = np.sqrt(grad_x**2 + grad_y**2 + grad_z**2)\n    grad_score = np.mean(grad_mag)\n    \n    # Method 3: standard deviation (simple but effective)\n    std_score = np.std(cube)\n    \n    # Combined score\n    combined_score = (lap_score * 0.3) + (grad_score * 0.3) + (std_score * 0.4)\n    return combined_score\n\ndef is_valid_negative_sample(cube, lung_mask=None, nodule_coords=None, min_distance=20, min_texture_score=0.1, cube_origin=(0, 0, 0)):\n    \"\"\"\n    Check whether a cube is a valid negative sample\n    \n    Args:\n        cube: 32x32x32 cube data\n        lung_mask: lung mask for the cube; if given, check that the cube lies inside the lung\n        nodule_coords: list of known nodule coordinates in (z, y, x) voxel order; if given, check that the cube is far enough from every known nodule\n        min_distance: minimum Euclidean distance (voxels) from the cube center to any known nodule\n        min_texture_score: minimum texture score threshold\n        cube_origin: (z, y, x) index of the cube's first voxel in the full volume, used to place the cube center in volume coordinates\n        \n    Returns:\n        bool indicating whether the cube is a valid negative sample\n    \"\"\"\n    # Compute the texture score\n    score = texture_score(cube)\n    if score < min_texture_score:\n        return False\n    \n    # Check that the cube lies inside the lung\n    if lung_mask is not None:\n        if np.mean(lung_mask > 0) < 0.6:  # require at least 60% of the voxels to be inside the lung\n            return False\n    \n    # Check that the cube is far enough from known nodules\n    if nodule_coords is not None and len(nodule_coords) > 0:\n        # Cube center in full-volume coordinates (same frame as nodule_coords)\n        cube_center = np.array(cube_origin) + CUBE_SIZE // 2\n        for nodule_coord in 
nodule_coords:\n            distance = np.linalg.norm(cube_center - np.array(nodule_coord))\n            if distance < min_distance:\n                return False\n    \n    return True\n\ndef select_negative_samples_from_ct(ct_data, nodule_coords=None, num_samples=100, strategy='random'):\n    \"\"\"\n    Select negative samples from CT data\n    \n    Args:\n        ct_data: CTData object\n        nodule_coords: list of known nodule coordinates, in (z, y, x) voxel order\n        num_samples: number of negative samples to select\n        strategy: selection strategy, 'random' or 'kmeans' (clustering)\n        \n    Returns:\n        list of negative samples, each a 32x32x32 cube\n    \"\"\"\n    if ct_data.lung_seg_img is None:\n        print(\"Lung region not segmented yet, segmenting...\")\n        ct_data.filter_lung_img_mask()\n    \n    lung_img = ct_data.lung_seg_img\n    lung_mask = ct_data.lung_seg_mask\n    \n    # Get the lung bounding box\n    z_indices, y_indices, x_indices = np.where(lung_mask > 0)\n    if len(z_indices) == 0:\n        print(\"No lung region found\")\n        return []\n    \n    z_min, z_max = z_indices.min(), z_indices.max()\n    y_min, y_max = y_indices.min(), y_indices.max()\n    x_min, x_max = x_indices.min(), x_indices.max()\n    \n    # Clamp the range so that a full cube always fits inside the volume\n    z_min = max(0, z_min)\n    y_min = max(0, y_min)\n    x_min = max(0, x_min)\n    z_max = min(lung_img.shape[0] - CUBE_SIZE, z_max)\n    y_max = min(lung_img.shape[1] - CUBE_SIZE, y_max)\n    x_max = min(lung_img.shape[2] - CUBE_SIZE, x_max)\n    if z_max < z_min or y_max < y_min or x_max < x_min:\n        print(\"Lung region is too small to hold a full cube\")\n        return []\n    \n    # Collect candidates\n    if strategy == 'random':\n        # Random selection strategy\n        candidate_samples = []\n        attempts = 0\n        max_attempts = num_samples * 10  # maximum number of attempts\n        \n        while len(candidate_samples) < num_samples and attempts < max_attempts:\n            # Pick a random start corner\n            z = random.randint(z_min, z_max)\n            y = random.randint(y_min, y_max)\n            x = random.randint(x_min, x_max)\n            \n            # Extract the cube and its lung mask\n            cube = lung_img[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE]\n            cube_mask = lung_mask[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE]\n            \n            # 
Check whether it is a valid negative sample\n            if is_valid_negative_sample(cube, cube_mask, nodule_coords, cube_origin=(z, y, x)):\n                candidate_samples.append({\n                    'cube': cube,\n                    'position': (z, y, x),\n                    'score': texture_score(cube)\n                })\n            \n            attempts += 1\n        \n        # Sort by texture score and keep the best samples\n        candidate_samples.sort(key=lambda x: x['score'], reverse=True)\n        selected_samples = candidate_samples[:num_samples]\n        \n        return [sample['cube'] for sample in selected_samples]\n    \n    elif strategy == 'kmeans':\n        # Clustering-based selection strategy (suited to large-scale training)\n        # First generate a large pool of candidates\n        all_candidates = []\n        for _ in range(num_samples * 5):\n            z = random.randint(z_min, z_max)\n            y = random.randint(y_min, y_max)\n            x = random.randint(x_min, x_max)\n            \n            cube = lung_img[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE]\n            cube_mask = lung_mask[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE]\n            \n            if is_valid_negative_sample(cube, cube_mask, nodule_coords, cube_origin=(z, y, x)):\n                # Extract features (simple summary statistics)\n                mean_val = np.mean(cube)\n                std_val = np.std(cube)\n                texture = texture_score(cube)\n                \n                all_candidates.append({\n                    'cube': cube,\n                    'position': (z, y, x),\n                    'features': [mean_val, std_val, texture],\n                    'score': texture\n                })\n        \n        if len(all_candidates) < num_samples:\n            print(f\"Warning: only {len(all_candidates)} candidates found, fewer than the {num_samples} requested\")\n            return [c['cube'] for c in all_candidates]\n        \n        # Build the feature matrix\n        features = np.array([c['features'] for c in all_candidates])\n        \n        # Standardize the features (constant columns get std 1 to avoid division by zero)\n        feature_std = features.std(axis=0)\n        feature_std[feature_std == 0] = 1.0\n        features = (features - features.mean(axis=0)) / feature_std\n        \n        # Cluster with KMeans\n        n_clusters = 
min(num_samples, len(all_candidates))\n        kmeans = KMeans(n_clusters=n_clusters, random_state=42)\n        clusters = kmeans.fit_predict(features)\n        \n        # Pick the best sample from each cluster\n        selected_samples = []\n        for i in range(n_clusters):\n            cluster_samples = [all_candidates[j] for j in range(len(all_candidates)) if clusters[j] == i]\n            if cluster_samples:\n                best_sample = max(cluster_samples, key=lambda x: x['score'])\n                selected_samples.append(best_sample)\n        \n        return [sample['cube'] for sample in selected_samples]\n    \n    else:\n        raise ValueError(f\"Unsupported selection strategy: {strategy}\")\n\ndef generate_negative_samples(ct_paths, output_dir, nodules_csv=None, samples_per_ct=50, strategy='random'):\n    \"\"\"\n    Generate negative samples from multiple CT scans\n    \n    Args:\n        ct_paths: list of CT file or directory paths\n        output_dir: output directory\n        nodules_csv: path to a CSV file with known nodule information\n        samples_per_ct: number of negative samples to generate per CT scan\n        strategy: sample selection strategy, 'random' or 'kmeans'\n    \"\"\"\n    os.makedirs(output_dir, exist_ok=True)\n    \n    # Load known nodule information\n    nodule_coords_by_patient = {}\n    if nodules_csv and os.path.exists(nodules_csv):\n        import pandas as pd\n        nodules_df = pd.read_csv(nodules_csv)\n        for _, row in nodules_df.iterrows():\n            patient_id = row['patient_id']\n            if patient_id not in nodule_coords_by_patient:\n                nodule_coords_by_patient[patient_id] = []\n            # Store in (z, y, x) order to match the (z, y, x) cube positions used during sampling\n            nodule_coords_by_patient[patient_id].append(\n                (row['voxel_z'], row['voxel_y'], row['voxel_x'])\n            )\n    \n    # Process each CT scan\n    for ct_path in ct_paths:\n        try:\n            # Get the patient ID\n            if os.path.isfile(ct_path):\n                patient_id = os.path.splitext(os.path.basename(ct_path))[0]\n            else:\n                patient_id = os.path.basename(ct_path)\n            \n            print(f\"Processing patient {patient_id}...\")\n            \n            # Load the CT data\n            if os.path.isfile(ct_path) and ct_path.endswith('.mhd'):\n                
ct_data = CTData.from_mhd(ct_path)\n            elif os.path.isdir(ct_path):\n                ct_data = CTData.from_dicom(ct_path)\n            else:\n                print(f\"Skipping unsupported file type: {ct_path}\")\n                continue\n            \n            # Get the known nodule coordinates\n            nodule_coords = nodule_coords_by_patient.get(patient_id, None)\n            \n            # Select negative samples\n            negative_samples = select_negative_samples_from_ct(\n                ct_data, \n                nodule_coords=nodule_coords,\n                num_samples=samples_per_ct,\n                strategy=strategy  # 'random' or 'kmeans'\n            )\n            \n            # Save the negative samples\n            for i, sample in enumerate(negative_samples):\n                output_path = os.path.join(output_dir, f\"{patient_id}_neg_{i:03d}.npy\")\n                np.save(output_path, sample)\n            \n            print(f\"Generated {len(negative_samples)} negative samples for patient {patient_id}\")\n            \n        except Exception as e:\n            print(f\"Error while processing {ct_path}: {str(e)}\")\n    \n    print(\"Negative sample generation complete!\")\n\ndef visualize_samples(samples, output_path=None, cols=5):\n    \"\"\"Visualize sample cubes for quality checking\"\"\"\n    rows = (len(samples) + cols - 1) // cols\n    # squeeze=False always returns a 2-D array of Axes, so flatten() also works for a 1x1 grid\n    fig, axes = plt.subplots(rows, cols, figsize=(cols*3, rows*3), squeeze=False)\n    axes = axes.flatten()\n    \n    for i, sample in enumerate(samples):\n        if i < len(axes):\n            # Show the middle slice of the cube\n            middle_slice = sample[sample.shape[0]//2]\n            axes[i].imshow(middle_slice, cmap='gray')\n            axes[i].set_title(f\"Sample {i}\")\n            axes[i].axis('off')\n    \n    # Hide the unused subplots\n    for i in range(len(samples), len(axes)):\n        axes[i].axis('off')\n    \n    plt.tight_layout()\n    \n    if output_path:\n        plt.savefig(output_path)\n        plt.close()\n    else:\n        plt.show()\n\nif __name__ == \"__main__\":\n    import argparse\n    \n    parser = argparse.ArgumentParser(description='Negative sample selection tool')\n    parser.add_argument('--input', 
type=str, required=True, help='input CT file or directory, or a text file listing several CT paths')\n    parser.add_argument('--output', type=str, required=True, help='output directory')\n    parser.add_argument('--nodules', type=str, default=None, help='CSV file with known nodule information')\n    parser.add_argument('--num-samples', type=int, default=50, help='number of negative samples to generate per CT scan')\n    parser.add_argument('--strategy', type=str, default='random', choices=['random', 'kmeans'], help='sample selection strategy')\n    \n    args = parser.parse_args()\n    \n    # Handle the input argument\n    if os.path.isfile(args.input) and args.input.endswith('.txt'):\n        # Read the list of CT paths from a text file\n        with open(args.input, 'r') as f:\n            ct_paths = [line.strip() for line in f if line.strip()]\n    else:\n        # Single CT path\n        ct_paths = [args.input]\n    \n    # Generate the negative samples\n    generate_negative_samples(\n        ct_paths,\n        args.output,\n        nodules_csv=args.nodules,\n        samples_per_ct=args.num_samples,\n        strategy=args.strategy\n    ) "
  },
  {
    "path": "inference/pytorch_nodule_detector.py",
    "content": "import os\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom torch.nn import functional as F\nfrom datetime import datetime\nimport logging\nimport time\nfrom scipy import ndimage\nfrom data.dataclass.CTData import CTData\nfrom data.dataclass.NoduleCube import normal_cube_to_tensor\nfrom data.preprocessing.luna16_invalid_nodule_filter import nodule_valid\nfrom models.pytorch_c3d_tiny import C3dTiny\n\n# Inference parameters\nCUBE_SIZE = 32  # size of the scanning cube, 32x32x32\nSCAN_STEP = 10  # scan stride: the cube moves 10 voxels per step\nPROB_THRESHOLD = 0.8  # threshold: only probabilities above this value count as a nodule\n\n# Logging setup\ndef setup_logger(log_dir=\"./inference_logs\"):\n    \"\"\"Configure logging\"\"\"\n    os.makedirs(log_dir, exist_ok=True)\n    log_file = os.path.join(log_dir, f\"inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log\")\n    \n    # Create the logger\n    logger = logging.getLogger('nodule_detection')\n    logger.setLevel(logging.INFO)\n    \n    # Create the file handler\n    file_handler = logging.FileHandler(log_file)\n    file_handler.setLevel(logging.INFO)\n    \n    # Create the console handler\n    console_handler = logging.StreamHandler()\n    console_handler.setLevel(logging.INFO)\n    \n    # Create the formatter\n    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n    file_handler.setFormatter(formatter)\n    console_handler.setFormatter(formatter)\n    \n    # Attach the handlers\n    logger.addHandler(file_handler)\n    logger.addHandler(console_handler)\n    \n    return logger\n\ndef load_ct_data(file_path):\n    \"\"\"\n    Load CT data (MHD or DICOM) and preprocess it\n    Args:\n        file_path: path to a CT file or folder\n    Returns:\n        CTData object\n    \"\"\"\n    # Decide whether the path is a file or a directory\n    if os.path.isfile(file_path):\n        # Assume an MHD file\n        if file_path.endswith('.mhd'):\n            ct_data = CTData.from_mhd(file_path)\n        else:\n            raise ValueError(f\"Unsupported file type: {file_path}\")\n    elif os.path.isdir(file_path):\n        # Assume a DICOM folder\n        ct_data = CTData.from_dicom(file_path)\n    else:\n        raise ValueError(f\"Path does not exist: {file_path}\")\n   
 # Resample to 1 mm isotropic spacing\n    ct_data = ct_data.resample_pixel(new_spacing=[1, 1, 1])\n    # Segment the lung region\n    ct_data.filter_lung_img_mask()\n    return ct_data\n\ndef load_model(evaL_model_path, device='cuda'):\n    \"\"\"\n    Load the PyTorch model\n    \n    Args:\n        evaL_model_path: path to the model weight file\n        device: compute device ('cuda' or 'cpu'); falls back to 'cpu' when CUDA is unavailable\n\n    Returns:\n        (model, device): the model with its weights loaded, and the device actually used\n    \"\"\"\n    if device == 'cuda' and not torch.cuda.is_available():\n        device = 'cpu'\n    model = C3dTiny().to(device)\n    # Load the weights\n    model.load_state_dict(torch.load(evaL_model_path, map_location=device))\n    model.eval()\n    return model, device\n\ndef get_lung_bounds(lung_mask):\n    \"\"\"Get the bounding box of the lung mask, handling the case where the left and right lungs are separate\"\"\"\n    if lung_mask.sum() == 0:\n        return None\n    # Find the lung regions with connected-component analysis\n    labeled_mask, num_features = ndimage.label(lung_mask > 0)\n    # If there are too many connected components, keep only the largest ones (usually the left and right lungs)\n    if num_features > 2:\n        # Count the voxels in each labeled region\n        region_sizes = np.array([(labeled_mask == i).sum() for i in range(1, num_features + 1)])\n        # Keep only the 2 largest regions (left and right lungs)\n        valid_labels = np.argsort(region_sizes)[-2:] + 1\n        # Build a new mask that contains only the largest regions\n        refined_mask = np.zeros_like(labeled_mask)\n        for label in valid_labels:\n            refined_mask[labeled_mask == label] = 1\n    else:\n        refined_mask = lung_mask > 0\n    \n    # Compute the lung extent of each slice along the z axis\n    z_ranges = []\n    margin = 5  # margin around each slice's lung extent (voxels)\n    \n    # Iterate over the z-axis slices\n    for z in range(refined_mask.shape[0]):\n        slice_mask = refined_mask[z]\n        if slice_mask.sum() > 100:  # the slice contains enough lung voxels\n            y_indices, x_indices = np.where(slice_mask)\n            if len(y_indices) > 0:\n                y_min = max(0, y_indices.min() - margin)\n                y_max = min(refined_mask.shape[1], y_indices.max() + margin)\n                x_min = max(0, x_indices.min() - margin)\n                x_max = min(refined_mask.shape[2], x_indices.max() + margin)\n                z_ranges.append((z, y_min, y_max, x_min, x_max))\n    \n    if not z_ranges:\n        return None\n    \n    # Determine the overall z range\n    z_min = z_ranges[0][0]\n    z_max 
= z_ranges[-1][0] + 1\n    \n    # 收集所有y和x范围\n    scan_regions = []\n    for z_slice, y_min, y_max, x_min, x_max in z_ranges:\n        scan_regions.append({\n            'z': z_slice,\n            'y_min': y_min,\n            'y_max': y_max,\n            'x_min': x_min,\n            'x_max': x_max\n        })\n    \n    return {\n        'z_min': z_min,\n        'z_max': z_max,\n        'regions': scan_regions\n    }\n\ndef scan_ct_data(ct_data, model, device, logger, step=SCAN_STEP):\n    \"\"\"\n    扫描整个CT图像，预测结节位置 - 优化版\n    \n    Args:\n        ct_data: CTData对象\n        model: PyTorch模型\n        device: 计算设备\n        logger: 日志对象\n        step: 扫描步长\n        \n    Returns:\n        包含结节信息的DataFrame\n    \"\"\"\n    logger.info(\"开始扫描CT数据...\")\n    \n    # 获取肺部分割后的图像数据\n    lung_img = ct_data.lung_seg_img\n    lung_mask = ct_data.lung_seg_mask\n    \n    # 获取肺部边界信息\n    bounds = get_lung_bounds(lung_mask)\n    if bounds is None:\n        logger.warning(\"未能找到有效的肺部区域\")\n        return pd.DataFrame(columns=['voxel_coord_x', 'voxel_coord_y', 'voxel_coord_z', \n                                    'world_coord_x', 'world_coord_y', 'world_coord_z', 'prob'])\n    \n    logger.info(f\"已确定肺部区域: Z轴范围 {bounds['z_min']} 到 {bounds['z_max']}, 共 {len(bounds['regions'])} 个切片\")\n    \n    # 创建存储结果的列表\n    results = []\n    \n    # 计算需要扫描的总体素数估计\n    total_voxels = 0\n    for region in bounds['regions']:\n        y_range = region['y_max'] - region['y_min']\n        x_range = region['x_max'] - region['x_min']\n        total_voxels += (y_range // step + 1) * (x_range // step + 1)\n    \n    logger.info(f\"预计扫描体素数: {total_voxels}\")\n    \n    # 开始计时\n    start_time = time.time()\n    batch_size = 32  # 增大批处理大小提高GPU利用率\n    batch_inputs = []\n    batch_positions = []\n    # 跟踪进度\n    processed_voxels = 0\n    skipped_voxels = 0\n    # 设置肺部组织比例阈值\n    lung_tissue_threshold = 0.1  # 立方体中肺部组织的最小比例\n    # 逐切片扫描肺部区域\n    for z_idx, region in enumerate(bounds['regions']):\n        z = 
region['z']\n        # 检查是否可以放置一个完整的立方体\n        if z + CUBE_SIZE > lung_img.shape[0]:\n            continue\n        # 在当前切片上扫描\n        for y in range(region['y_min'], region['y_max'] - CUBE_SIZE + 1, step):\n            for x in range(region['x_min'], region['x_max'] - CUBE_SIZE + 1, step):\n                # 提取当前位置的肺部掩码立方体\n                mask_cube = lung_mask[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE]\n                # 计算肺部组织比例\n                lung_ratio = np.mean(mask_cube)\n                # 如果肺部组织比例过低，跳过\n                if lung_ratio < lung_tissue_threshold:\n                    skipped_voxels += 1\n                    continue\n                # 提取当前位置的立方体\n                cube = lung_img[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE]\n                # 预处理立方体数据\n                cube_tensor = normal_cube_to_tensor(cube)\n                # 转化为 batch\n                cube_tensor = cube_tensor.unsqueeze(0)\n                # 添加到批处理\n                batch_inputs.append(cube_tensor)\n                batch_positions.append((z, y, x))\n                # 当批处理达到指定大小时进行预测\n                if len(batch_inputs) == batch_size:\n                    # 处理当前批次\n                    process_batch(batch_inputs, batch_positions, model, device, ct_data, results)\n                    batch_inputs = []\n                    batch_positions = []\n                processed_voxels += 1\n                # 定期报告进度\n                if (processed_voxels + skipped_voxels) % 1000 == 0:\n                    elapsed_time = time.time() - start_time\n                    progress = processed_voxels / total_voxels * 100 if total_voxels > 0 else 0\n                    logger.info(f\"处理进度: {processed_voxels}/{total_voxels} ({progress:.2f}%), \"\n                                f\"已跳过: {skipped_voxels}, 耗时: {elapsed_time:.2f}秒\")\n    \n    # 处理最后一个批次\n    if batch_inputs:\n        process_batch(batch_inputs, batch_positions, model, device, ct_data, results)\n    \n    # 创建DataFrame\n    if 
results:\n        results_df = pd.DataFrame(results)\n        logger.info(f\"扫描完成! 发现 {len(results_df)} 个可能的结节\")\n    else:\n        results_df = pd.DataFrame(columns=['voxel_coord_x', 'voxel_coord_y', 'voxel_coord_z', \n                                          'world_coord_x', 'world_coord_y', 'world_coord_z', 'prob'])\n        logger.info(\"扫描完成! 未发现任何结节\")\n    \n    return results_df\n\ndef process_batch(batch_inputs, batch_positions, model, device, ct_data, results):\n    \"\"\"处理一个批次的数据\"\"\"\n    # 合并批处理\n    batch_tensor = torch.cat(batch_inputs, dim=0).to(device)\n    # 预测\n    with torch.no_grad():\n        batch_outputs = model(batch_tensor)\n        batch_probs = F.softmax(batch_outputs, dim=1)[:, 1]  # 类别1的概率\n    # 处理每个预测结果\n    for i, prob in enumerate(batch_probs):\n        prob_value = prob.item()\n        if prob_value > PROB_THRESHOLD:\n            z_pos, y_pos, x_pos = batch_positions[i]\n            # 计算中心点坐标\n            center_z = z_pos + CUBE_SIZE // 2\n            center_y = y_pos + CUBE_SIZE // 2\n            center_x = x_pos + CUBE_SIZE // 2\n            # 将体素坐标转换为世界坐标 (mm)\n            world_coord = ct_data.voxel_to_world([center_x, center_y, center_z])\n            # 添加结果\n            results.append({\n                'voxel_coord_x': center_x,\n                'voxel_coord_y': center_y,\n                'voxel_coord_z': center_z,\n                'world_coord_x': world_coord[0],\n                'world_coord_y': world_coord[1],\n                'world_coord_z': world_coord[2],\n                'prob': prob_value\n            })\n\ndef reduce_overlapping_nodules(results_df, distance_threshold=15):\n    \"\"\"\n    合并重叠的结节预测，使用更严格的距离阈值\n    \n    Args:\n        results_df: 包含结节预测的DataFrame\n        distance_threshold: 合并的距离阈值(体素)\n        \n    Returns:\n        合并后的结节DataFrame\n    \"\"\"\n    if len(results_df) <= 1:\n        return results_df\n    # 按概率从高到低排序\n    sorted_df = results_df.sort_values('prob', 
ascending=False).reset_index(drop=True)\n    # 创建一个布尔掩码来标记要保留的行\n    keep_mask = np.ones(len(sorted_df), dtype=bool)\n    # 对每一行\n    for i in range(len(sorted_df)):\n        if not keep_mask[i]:\n            continue  # 如果此行已被标记为删除，则跳过\n        # 获取当前结节的坐标\n        current = sorted_df.iloc[i]\n        # 比较与其他所有结节的距离\n        for j in range(i + 1, len(sorted_df)):\n            if not keep_mask[j]:\n                continue  # 如果要比较的行已被标记为删除，则跳过\n            # 获取要比较的结节坐标\n            compare = sorted_df.iloc[j]\n            # 计算3D欧氏距离\n            distance = np.sqrt(\n                (current['voxel_coord_x'] - compare['voxel_coord_x']) ** 2 +\n                (current['voxel_coord_y'] - compare['voxel_coord_y']) ** 2 +\n                (current['voxel_coord_z'] - compare['voxel_coord_z']) ** 2\n            )\n            # 如果距离小于阈值，标记为删除\n            if distance < distance_threshold:\n                keep_mask[j] = False\n    # 应用掩码，仅保留未被标记为删除的行\n    reduced_df = sorted_df[keep_mask].reset_index(drop=True)\n    return reduced_df\n\ndef filter_false_positives(nodules_df, ct_data, max_nodules=10):\n    \"\"\"\n    基于解剖学和统计特征过滤假阳性结节\n    \n    Args:\n        nodules_df: 包含结节预测的DataFrame\n        ct_data: CTData对象\n        max_nodules: 每个患者允许的最大结节数量\n\n    Returns:\n        过滤后的结节DataFrame\n    \"\"\"\n    if nodules_df.empty:\n        return nodules_df\n    # 获取肺部掩码\n    lung_mask = ct_data.lung_seg_mask\n    # 1. 限制结节总数\n    if len(nodules_df) > max_nodules:\n        # 只保留概率最高的前N个结节\n        nodules_df = nodules_df.sort_values('prob', ascending=False).head(max_nodules)\n    # 2. 
基于位置过滤\n    filtered_rows = []\n    for i, row in nodules_df.iterrows():\n        x, y, z = int(row['voxel_coord_x']), int(row['voxel_coord_y']), int(row['voxel_coord_z'])\n        this_nodule_valid = nodule_valid(ct_data, x, y, z)\n        if this_nodule_valid:\n            # 通过所有检查，保留此结节\n            filtered_rows.append(row)\n    # 创建新的DataFrame\n    filtered_df = pd.DataFrame(filtered_rows)\n    # 3. 基于概率再次过滤\n    # 如果概率低于阈值，移除\n    # high_prob_threshold = 0.95  # 高概率阈值\n    # filtered_df = filtered_df[filtered_df['prob'] >= high_prob_threshold]\n    return filtered_df\n\ndef format_results(results_df, ct_data, patient_id):\n    \"\"\"\n    格式化结果为最终输出的DataFrame\n    \n    Args:\n        results_df: 合并后的结节DataFrame\n        ct_data: CTData对象\n        patient_id: 患者ID\n        \n    Returns:\n        包含结节信息的最终DataFrame\n    \"\"\"\n    # 如果没有结节，返回空的DataFrame\n    if results_df.empty:\n        return pd.DataFrame(columns=['patient_id', 'nodule_id', 'voxel_x', 'voxel_y', 'voxel_z', \n                                     'world_x', 'world_y', 'world_z', 'diameter_mm', 'prob'])\n    # 创建最终结果列表\n    final_results = []\n    # 处理每个结节\n    for i, row in results_df.iterrows():\n        # 设置默认直径为CUBE_SIZE / 2\n        diameter_mm = CUBE_SIZE / 2\n        # 添加结果\n        final_results.append({\n            'patient_id': patient_id,\n            'nodule_id': i + 1,\n            'voxel_x': int(row['voxel_coord_x']),\n            'voxel_y': int(row['voxel_coord_y']),\n            'voxel_z': int(row['voxel_coord_z']),\n            'world_x': row['world_coord_x'],\n            'world_y': row['world_coord_y'],\n            'world_z': row['world_coord_z'],\n            'diameter_mm': diameter_mm,\n            'prob': row['prob']\n        })\n    \n    # 创建DataFrame\n    final_df = pd.DataFrame(final_results)\n    \n    return final_df\n\ndef detect_nodules(file_path, model_path, detect_patient_id=None, device='cuda'):\n    \"\"\"\n    主函数：对CT数据进行结节检测\n    \n    Args:\n        
file_path: CT文件或文件夹路径\n        model_path: 模型权重文件路径\n        detect_patient_id: 患者ID，如果为None则使用文件名\n        device: 计算设备 ('cuda' 或 'cpu')\n        \n    Returns:\n        包含结节信息的DataFrame\n    \"\"\"\n    # 设置日志\n    logger = setup_logger()\n    # 如果患者ID为None，则使用文件名\n    if detect_patient_id is None:\n        if os.path.isfile(file_path):\n            detect_patient_id = os.path.splitext(os.path.basename(file_path))[0]\n        else:\n            detect_patient_id = os.path.basename(file_path)\n    logger.info(f\"开始处理患者 {detect_patient_id} 的CT数据\")\n    try:\n        # 加载CT数据\n        logger.info(f\"加载CT数据: {file_path}\")\n        ct_data = load_ct_data(file_path)\n        # 加载模型\n        logger.info(f\"加载模型: {model_path}\")\n        model, device = load_model(model_path, device)\n        # 扫描CT数据\n        results_df = scan_ct_data(ct_data, model, device, logger)\n        # 合并重叠结节\n        logger.info(\"合并重叠结节...\")\n        reduced_df = reduce_overlapping_nodules(results_df)\n        logger.info(f\"合并后的结节数量: {len(reduced_df)}\")\n        # 过滤假阳性\n        logger.info(\"过滤假阳性结节...\")\n        filtered_df = filter_false_positives(reduced_df, ct_data)\n        logger.info(f\"过滤后的结节数量: {len(filtered_df)}\")\n        # 格式化结果\n        final_df = format_results(filtered_df, ct_data, patient_id)\n        logger.info(f\"检测完成，找到 {len(final_df)} 个结节\")\n        return final_df\n    except Exception as e:\n        logger.error(f\"检测过程中出错: {str(e)}\", exc_info=True)\n        raise\n    \nif __name__ == \"__main__\":\n    test_mhd = \"H:/luna16/subset8/1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139.mhd\"\n    model_path = \"../training/pytorch_checkpoints/best_model.pth\"\n    threshold = 0.7\n    patient_id = \"1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139\"\n    detect_result_csv = \"./c3d_classify_result-%s.csv\" %patient_id\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    # 运行检测\n    result_df = 
detect_nodules(test_mhd, model_path, None, device)\n    # 保存结果\n    result_df.to_csv(detect_result_csv, index=False, encoding=\"utf-8\")"
  },
  {
    "path": "models/__init__.py",
    "content": ""
  },
  {
    "path": "models/pytorch_c3d_tiny.py",
    "content": "import torch.nn as nn\n\n\nclass C3dTiny(nn.Module):\n    \"\"\"Tiny-C3D: a lightweight 3D CNN that classifies 32x32x32 cubes (input shape N x 1 x 32 x 32 x 32) into 2 classes.\"\"\"\n    def __init__(self):\n        super().__init__()\n        # Block 1: one 3D convolution; pools only along height and width\n        self.conv_block1 = nn.Sequential(\n            nn.Conv3d(in_channels=1, kernel_size=3, padding = 1, out_channels=64),\n            # Not in the original C3D; added here\n            nn.BatchNorm3d(64),\n            nn.ReLU(),\n            nn.MaxPool3d(kernel_size=(1,2,2), stride = (1,2,2))\n        )\n        # Block 2: one 3D convolution\n        self.conv_block2 = nn.Sequential(\n            nn.Conv3d(in_channels=64, kernel_size=3, padding = 1, out_channels=128),\n            # Not in the original C3D; added here\n            nn.BatchNorm3d(128),\n            nn.ReLU(),\n            nn.MaxPool3d(kernel_size=2)\n        )\n        self.drop_out1 = nn.Dropout(0.2)\n        # Block 3: two 3D convolutions\n        self.conv_block3 = nn.Sequential(\n            nn.Conv3d(in_channels = 128, kernel_size=3, padding = 1, out_channels=256),\n            nn.BatchNorm3d(256),\n            nn.ReLU(),\n            nn.Conv3d(in_channels=256, kernel_size=3, padding = 1, out_channels=256),\n            nn.BatchNorm3d(256),\n            nn.ReLU(),\n            nn.MaxPool3d(kernel_size=2)\n        )\n        self.drop_out2 = nn.Dropout(0.2)\n        # Block 4: two 3D convolutions\n        self.conv_block4 = nn.Sequential(\n            nn.Conv3d(in_channels = 256, kernel_size = 3, padding = 1, out_channels=512),\n            nn.BatchNorm3d(512),\n            nn.ReLU(),\n            nn.Conv3d(in_channels = 512, kernel_size = 3, padding = 1, out_channels = 512),\n            nn.BatchNorm3d(512),\n            nn.ReLU(),\n            nn.MaxPool3d(kernel_size=2)\n        )\n        self.drop_out3 = nn.Dropout(0.2)\n        self.flatten = nn.Flatten()\n        # Number of input features of fc1:\n        # the 32x32x32 input becomes 32x16x16 after pool1 (1,2,2),\n        # 16x8x8 after pool2 (2,2,2),\n        # 8x4x4 after pool3 (2,2,2)\n        # and 4x2x2 after pool4 (2,2,2),\n        # so the final feature map is 4x2x2 with 512 channels (512 * 4 * 2 * 2 = 8192 features)\n        self.fc1 = nn.Sequential(\n            nn.Linear(512 * 4 * 2 * 2, 512),\n            nn.ReLU()\n        )\n        self.fc2 = nn.Linear(512, 2)\n\n    def forward(self, x):\n        x = self.conv_block1(x)\n        x = self.conv_block2(x)\n        x = self.drop_out1(x)\n        x = self.conv_block3(x)\n        x = self.drop_out2(x)\n        x = self.conv_block4(x)\n        x = self.drop_out3(x)\n        x = self.flatten(x)\n        x = self.fc1(x)\n        x = self.fc2(x)\n        return x"
  },
  {
    "path": "training/__init__.py",
    "content": ""
  },
  {
    "path": "training/pytorch_logs/training_20250331_223230.log",
    "content": "2025-03-31 22:32:30,594 - c3d_training - INFO - 训练集: 23640个样本\n2025-03-31 22:32:30,594 - c3d_training - INFO - 验证集: 5910个样本\n2025-03-31 22:32:30,749 - c3d_training - INFO - 模型结构:\nC3dTiny(\n  (conv_block1): Sequential(\n    (0): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (2): ReLU()\n    (3): MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0, dilation=1, ceil_mode=False)\n  )\n  (conv_block2): Sequential(\n    (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (2): ReLU()\n    (3): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (drop_out1): Dropout(p=0.2, inplace=False)\n  (conv_block3): Sequential(\n    (0): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (2): ReLU()\n    (3): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (4): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (5): ReLU()\n    (6): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (drop_out2): Dropout(p=0.2, inplace=False)\n  (conv_block4): Sequential(\n    (0): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (2): ReLU()\n    (3): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n    (4): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n    (5): ReLU()\n    (6): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (drop_out3): Dropout(p=0.2, inplace=False)\n  (flatten): Flatten(start_dim=1, end_dim=-1)\n  (fc1): Sequential(\n    (0): Linear(in_features=8192, out_features=512, bias=True)\n    (1): ReLU()\n  )\n  (fc2): Linear(in_features=512, out_features=2, bias=True)\n)\n2025-03-31 22:33:51,563 - c3d_training - INFO - Epoch [1/15], Train Loss: 0.3146, Train Acc: 89.05%, Val Loss: 0.1623, Val Acc: 93.28%, Time: 80.81s\n2025-03-31 22:33:51,652 - c3d_training - INFO - Epoch [1]: 保存最佳模型, 验证准确率: 93.28%\n2025-03-31 22:35:12,255 - c3d_training - INFO - Epoch [2/15], Train Loss: 0.0834, Train Acc: 97.17%, Val Loss: 0.0728, Val Acc: 97.33%, Time: 80.39s\n2025-03-31 22:35:12,341 - c3d_training - INFO - Epoch [2]: 保存最佳模型, 验证准确率: 97.33%\n2025-03-31 22:36:33,755 - c3d_training - INFO - Epoch [3/15], Train Loss: 0.0504, Train Acc: 98.30%, Val Loss: 0.0383, Val Acc: 98.61%, Time: 81.20s\n2025-03-31 22:36:33,843 - c3d_training - INFO - Epoch [3]: 保存最佳模型, 验证准确率: 98.61%\n2025-03-31 22:37:55,955 - c3d_training - INFO - Epoch [4/15], Train Loss: 0.0368, Train Acc: 98.80%, Val Loss: 0.0522, Val Acc: 98.22%, Time: 81.89s\n2025-03-31 22:39:16,050 - c3d_training - INFO - Epoch [5/15], Train Loss: 0.0282, Train Acc: 99.04%, Val Loss: 0.0322, Val Acc: 98.98%, Time: 79.84s\n2025-03-31 22:39:16,142 - c3d_training - INFO - Epoch [5]: 保存最佳模型, 验证准确率: 98.98%\n2025-03-31 22:40:37,157 - c3d_training - INFO - Epoch [6/15], Train Loss: 0.0244, Train Acc: 99.13%, Val Loss: 0.0393, Val Acc: 98.83%, Time: 80.80s\n2025-03-31 22:41:59,029 - c3d_training - INFO - Epoch [7/15], Train Loss: 0.0204, Train Acc: 99.40%, Val Loss: 0.0383, Val Acc: 98.95%, Time: 81.64s\n2025-03-31 22:43:20,016 - c3d_training - INFO - Epoch [8/15], Train Loss: 0.0219, Train Acc: 99.31%, Val Loss: 0.0578, Val Acc: 98.29%, Time: 80.75s\n2025-03-31 22:44:41,369 - c3d_training - INFO - Epoch [9/15], Train Loss: 0.0171, Train Acc: 99.52%, Val Loss: 0.0342, Val Acc: 99.17%, Time: 81.13s\n2025-03-31 22:44:41,473 - c3d_training - INFO - Epoch [9]: 保存最佳模型, 验证准确率: 99.17%\n2025-03-31 22:46:02,745 - c3d_training - INFO - Epoch [10/15], Train Loss: 0.0173, Train Acc: 99.49%, Val Loss: 0.0279, Val Acc: 99.15%, Time: 81.04s\n2025-03-31 22:47:24,128 - c3d_training - INFO - Epoch [11/15], Train Loss: 0.0176, Train Acc: 99.47%, Val Loss: 0.0349, Val Acc: 99.17%, Time: 81.15s\n2025-03-31 22:48:44,717 - c3d_training - INFO - Epoch [12/15], Train Loss: 0.0140, Train Acc: 99.53%, Val Loss: 0.0362, Val Acc: 99.27%, Time: 80.36s\n2025-03-31 22:48:44,807 - c3d_training - INFO - Epoch [12]: 保存最佳模型, 验证准确率: 99.27%\n2025-03-31 22:50:07,041 - c3d_training - INFO - Epoch [13/15], Train Loss: 0.0126, Train Acc: 99.62%, Val Loss: 0.0497, Val Acc: 98.87%, Time: 82.02s\n2025-03-31 22:51:27,896 - c3d_training - INFO - Epoch [14/15], Train Loss: 0.0152, Train Acc: 99.49%, Val Loss: 0.0271, Val Acc: 99.32%, Time: 80.61s\n2025-03-31 22:51:27,987 - c3d_training - INFO - Epoch [14]: 保存最佳模型, 验证准确率: 99.32%\n2025-03-31 22:52:49,409 - c3d_training - INFO - Epoch [15/15], Train Loss: 0.0127, Train Acc: 99.58%, Val Loss: 0.0277, Val Acc: 99.15%, Time: 81.20s\n2025-03-31 22:52:49,737 - c3d_training - INFO - 训练完成，最终模型已保存\n2025-03-31 22:52:50,003 - c3d_training - INFO - 训练完成!\n"
  },
  {
    "path": "training/train_c3d_pytorch.py",
    "content": "import os\nimport glob\nimport random\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.utils.data import Dataset, DataLoader\nfrom torch.utils.tensorboard import SummaryWriter\nfrom sklearn.model_selection import train_test_split\nimport matplotlib.pyplot as plt\n\nfrom data.dataclass.NoduleCube import normal_cube_to_tensor\n\nplt.rcParams['font.family'] = ['SimHei']  # Use the SimHei font so CJK text renders in plots\nplt.rcParams['axes.unicode_minus'] = False  # Render minus signs correctly with this font\nimport logging\nimport time\nfrom datetime import datetime\nfrom models.pytorch_c3d_tiny import C3dTiny\n\n# Training hyperparameters\nBATCH_SIZE = 64\nEPOCHS = 15\nLEARNING_RATE = 5e-4\nWEIGHT_DECAY = 1e-5\nCUBE_SIZE = 32\nDEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n# Logging setup\ndef setup_logger(log_dir=\"./pytorch_logs\"):\n    os.makedirs(log_dir, exist_ok=True)\n    log_file = os.path.join(log_dir, f\"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log\")\n    \n    # Create the logger\n    logger = logging.getLogger('c3d_training')\n    logger.setLevel(logging.INFO)\n    \n    # File handler\n    file_handler = logging.FileHandler(log_file, encoding=\"utf-8\")\n    file_handler.setLevel(logging.INFO)\n    \n    # Console handler\n    console_handler = logging.StreamHandler()\n    console_handler.setLevel(logging.INFO)\n    \n    # Formatter\n    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n    file_handler.setFormatter(formatter)\n    console_handler.setFormatter(formatter)\n    \n    # Attach the handlers\n    logger.addHandler(file_handler)\n    logger.addHandler(console_handler)\n    \n    return logger\n\nclass Luna16DataSet(Dataset):\n    \"\"\"Dataset of pre-extracted cubes stored as .npy files, labelled 1 (nodule) or 0 (non-nodule).\"\"\"\n    def __init__(self, files, labels, transform=None):\n        self.files = files\n        self.labels = labels\n        self.transform = transform\n\n    def __len__(self):\n        return len(self.files)\n\n    def __getitem__(self, idx):\n        npy_file = self.files[idx]\n        item_label = self.labels[idx]\n        item_label = torch.tensor(item_label, dtype=torch.long)\n        item_data = np.load(npy_file)\n        torch_item_data = normal_cube_to_tensor(item_data)\n        if self.transform is not None:\n            torch_item_data = self.transform(torch_item_data)\n        return torch_item_data, item_label\n\ndef load_train_val_data(positive_dir, negative_dir):\n    \"\"\"\n        Load the training and validation sets from two folders of .npy cubes.\n        Samples min(#positives, #negatives) positive cubes and up to twice as many\n        negative cubes, then holds out 20% of them for validation.\n    :param positive_dir: folder with the positive (nodule) cubes\n    :param negative_dir: folder with the negative (non-nodule) cubes\n    :return: files_train, files_val, label_train, label_val\n    \"\"\"\n    positive_files = glob.glob(os.path.join(positive_dir, \"*.npy\"))\n    negative_files = glob.glob(os.path.join(negative_dir, \"*.npy\"))\n    min_samples = min(len(positive_files), len(negative_files))\n\n    pos_files = random.sample(positive_files, min_samples)\n    # Aim for a 1:2 positive:negative ratio, capped at the number of available negatives\n    neg_files = random.sample(negative_files, min(2 * min_samples, len(negative_files)))\n    all_files = pos_files + neg_files\n    labels = np.concatenate([np.ones(len(pos_files)), np.zeros(len(neg_files))])\n\n    # train_test_split shuffles the samples before splitting\n    files_train, files_val, label_train, label_val = train_test_split(all_files, labels, test_size=0.2, random_state=42)\n    return files_train, files_val, label_train, label_val\n\ndef train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, logger, writer, epochs, save_dir):\n    best_val_acc = 0.0\n    train_losses = []\n    val_losses = []\n    val_accs = []\n    train_accs = []\n    os.makedirs(save_dir, exist_ok=True)\n    for epoch in range(epochs):\n        model.train()\n        train_loss = 0.0\n        correct = 0\n        total = 0\n        start_time = time.time()\n        for i, (inputs, labels) in enumerate(train_loader):\n            inputs = inputs.to(DEVICE)\n            labels = labels.to(DEVICE)\n            # Reset the gradients\n            optimizer.zero_grad()\n            # Forward pass\n            outputs = model(inputs)\n            loss = criterion(outputs, labels)\n            \n            # Skip the batch if the loss is NaN or Inf\n            if torch.isnan(loss).any() or torch.isinf(loss).any():\n                logger.warning(\"Warning: loss is NaN or Inf, skipping this batch\")\n                continue\n                \n            # Backward pass\n            loss.backward()\n            \n            # Clip gradients to prevent exploding gradients\n            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n            \n            optimizer.step()\n            # Accumulate loss and accuracy statistics\n            train_loss += loss.item()\n            _, predicteds = outputs.max(1)\n            total += labels.size(0)\n            correct += predicteds.eq(labels).sum().item()\n            # Print progress every 100 batches\n            if (i + 1) % 100 == 0:\n                print(f\"{epoch + 1}/{epochs}, Batch [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}\")\n        # Average training loss of this epoch\n        epoch_train_loss = train_loss / len(train_loader)\n        # Training accuracy of this epoch\n        epoch_train_acc = 100.0 * correct / total\n        train_losses.append(epoch_train_loss)\n        train_accs.append(epoch_train_acc)\n\n        # Validation\n        model.eval()\n        val_loss = 0.0\n        val_correct = 0\n        val_total = 0\n        with torch.no_grad():\n            for val_inputs, val_labels in val_loader:\n                val_inputs = val_inputs.to(DEVICE)\n                val_labels = val_labels.to(DEVICE)\n                # Forward pass only\n                val_outputs = model(val_inputs)\n                batch_val_loss = criterion(val_outputs, val_labels)\n                val_loss += batch_val_loss.item()\n                _, predicted = val_outputs.max(1)\n                val_total += val_labels.size(0)\n                val_correct += predicted.eq(val_labels).sum().item()\n        # Average validation loss and accuracy\n        epoch_val_loss = val_loss / len(val_loader)\n        epoch_val_acc = 100.0 * val_correct / val_total\n        val_losses.append(epoch_val_loss)\n        val_accs.append(epoch_val_acc)\n        \n        # Update the learning rate\n        if scheduler is not None:\n            scheduler.step(epoch_val_loss)\n            \n        # Log to TensorBoard\n        writer.add_scalar('Loss/train', epoch_train_loss, epoch)\n        writer.add_scalar('Loss/val', epoch_val_loss, epoch)\n        writer.add_scalar('Accuracy/train', epoch_train_acc, epoch)\n        writer.add_scalar('Accuracy/val', epoch_val_acc, epoch)\n        writer.add_scalar('Learning_rate', optimizer.param_groups[0]['lr'], epoch)\n        \n        end_time = time.time()\n        epoch_time = end_time - start_time\n        \n        # Log the epoch summary\n        logger.info(f'Epoch [{epoch+1}/{epochs}], '\n                   f'Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%, '\n                   f'Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.2f}%, '\n                   f'Time: {epoch_time:.2f}s')\n        \n        # Save the best model so far\n        if epoch_val_acc > best_val_acc:\n            best_val_acc = epoch_val_acc\n            torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))\n            logger.info(f'Epoch [{epoch+1}]: saved best model, validation accuracy: {epoch_val_acc:.2f}%')\n        \n        # Save a checkpoint every epoch (re-enable the condition below to save only every 5 epochs)\n        # if (epoch + 1) % 5 == 0:\n        torch.save({\n            'epoch': epoch + 1,\n            'model_state_dict': model.state_dict(),\n            'optimizer_state_dict': optimizer.state_dict(),\n            'train_loss': epoch_train_loss,\n            'val_loss': epoch_val_loss,\n            'train_acc': epoch_train_acc,\n            'val_acc': epoch_val_acc\n        }, os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'))\n    \n    # Save the final model\n    torch.save(model.state_dict(), os.path.join(save_dir, 'final_model.pth'))\n    logger.info('Training finished, final model saved')\n    \n    # Plot the loss and accuracy curves\n    plot_metrics(train_losses, val_losses, train_accs, val_accs, save_dir)\n    \n    return train_losses, val_losses, train_accs, val_accs\n\ndef plot_metrics(train_losses, val_losses, train_accs, val_accs, save_dir):\n    \"\"\"Plot the loss and accuracy curves and save them to save_dir/training_metrics.png.\"\"\"\n    plt.figure(figsize=(12, 5))\n    \n    # Loss curves\n    plt.subplot(1, 2, 1)\n    plt.plot(train_losses, label='Train Loss')\n    plt.plot(val_losses, label='Validation Loss')\n    plt.xlabel('Epochs')\n    plt.ylabel('Loss')\n    plt.title('Training and Validation Loss')\n    plt.legend()\n    plt.grid(True)\n    \n    # Accuracy curves\n    plt.subplot(1, 2, 2)\n    plt.plot(train_accs, label='Train Accuracy')\n    plt.plot(val_accs, label='Validation Accuracy')\n    plt.xlabel('Epochs')\n    plt.ylabel('Accuracy (%)')\n    plt.title('Training and Validation Accuracy')\n    plt.legend()\n    plt.grid(True)\n    \n    plt.tight_layout()\n    plt.savefig(os.path.join(save_dir, 'training_metrics.png'))\n    plt.close()\n\ndef main():\n    # Fix the random seeds\n    torch.manual_seed(42)\n    np.random.seed(42)\n    random.seed(42)\n    \n    # Create the output directories\n    log_dir = \"./pytorch_logs\"\n    checkpoint_dir = \"./pytorch_checkpoints\"\n    os.makedirs(log_dir, exist_ok=True)\n    os.makedirs(checkpoint_dir, exist_ok=True)\n    \n    # Set up logging\n    logger = setup_logger(log_dir)\n    \n    # Set up TensorBoard\n    writer = SummaryWriter(log_dir=os.path.join(log_dir, 'tensorboard'))\n    \n    # Load the data\n    pos_sample_dir = r\"J:\\luna16_processed\\positive_npys\"\n    neg_sample_dir = r\"J:\\luna16_processed\\negative_npys\"\n    files_train, files_val, labels_train, labels_val = load_train_val_data(pos_sample_dir, neg_sample_dir)\n    \n    logger.info(f\"Training set: {len(files_train)} samples\")\n    logger.info(f\"Validation set: {len(files_val)} samples\")\n    \n    # Create the datasets and data loaders\n    train_dataset = Luna16DataSet(files_train, labels_train)\n    val_dataset = Luna16DataSet(files_val, labels_val)\n    \n    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)\n    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)\n    \n    # Create the model\n    model = C3dTiny().to(DEVICE)\n    logger.info(f\"Model structure:\\n{model}\")\n    \n    # Loss function and optimizer\n    criterion = nn.CrossEntropyLoss()\n    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, eps=1e-8)\n    # Learning-rate scheduler: halve the LR when the validation loss stops improving\n    scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n        optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6\n    )\n    \n    # Train the model\n    train_losses, val_losses, train_accs, val_accs = train_model(\n        model=model,\n        train_loader=train_loader,\n        val_loader=val_loader,\n        optimizer=optimizer,\n        criterion=criterion,\n        scheduler=scheduler,\n        logger=logger,\n        writer=writer,\n        epochs=EPOCHS,\n        save_dir=checkpoint_dir\n    )\n    \n    # Close the TensorBoard writer\n    writer.close()\n    \n    logger.info(\"Training done!\")\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "util/__init__.py",
    "content": ""
  },
  {
    "path": "util/dicom_util.py",
    "content": "import os\nimport glob\nimport pydicom\nimport numpy as np\nimport cv2\nfrom tqdm import tqdm\n\nfrom util.seg_util import get_segmented_lungs,normalize_hu_values\nfrom util.image_util import rescale_patient_images\n\n\ndef is_dicom_file(filename):\n    '''\n       Check whether a file is a DICOM file (DICOM files carry the DICM marker at byte offset 128)\n    :param filename:      file to check\n    :return:              True if the file is a DICOM file\n    '''\n    with open(filename, 'rb') as file_stream:\n        file_stream.seek(128)\n        data = file_stream.read(4)\n    return data == b'DICM'\n\ndef get_dicom_thickness(dicom_slices):\n    \"\"\"\n        Compute the slice thickness\n    :param dicom_slices:    list of DICOM datasets read with pydicom\n    :return:                slice thickness in mm\n    \"\"\"\n    if len(dicom_slices) > 1:\n        try:\n            slice_thickness = abs(dicom_slices[0].ImagePositionPatient[2] - dicom_slices[1].ImagePositionPatient[2])\n        except Exception:\n            try:\n                slice_thickness = abs(dicom_slices[0].SliceLocation - dicom_slices[1].SliceLocation)\n            except Exception:\n                # If it cannot be computed, fall back to the SliceThickness tag\n                try:\n                    slice_thickness = float(dicom_slices[0].SliceThickness)\n                except Exception:\n                    print(\"Warning: could not determine the slice thickness, using the default of 1.0 mm\")\n                    slice_thickness = 1.0\n    else:\n        try:\n            slice_thickness = float(dicom_slices[0].SliceThickness)\n        except Exception:\n            print(\"Warning: only one slice, cannot compute the slice thickness, using the default of 1.0 mm\")\n            slice_thickness = 1.0\n    return slice_thickness\n\ndef load_dicom_slices(dicom_path):\n    \"\"\"\n        Recursively find the DICOM files under a directory and load them as a list of slices sorted by InstanceNumber\n\n    :param dicom_path:     a DICOM directory\n    :return:            list of pydicom datasets\n    \"\"\"\n    dicom_files = []\n    for root, _, files in os.walk(dicom_path):\n        for file in files:\n            if file.lower().endswith(('.dcm', '.dicom')):\n                # root yielded by os.walk already starts with dicom_path\n                real_file = os.path.join(root, file)\n                if is_dicom_file(real_file):\n                    dicom_files.append(real_file)\n    if not dicom_files:\n        raise ValueError(f\"No DICOM files found under {dicom_path}\")\n    # Load all slices\n    slices = []\n    for file in dicom_files:\n        try:\n            ds = pydicom.dcmread(file)\n            slices.append(ds)\n        except Exception as e:\n            print(f\"Could not read DICOM file {file}: {e}\")\n    # Sort the slices by InstanceNumber (usually, but not always, the same as the Z order)\n    slices.sort(key=lambda x: int(x.InstanceNumber))\n    slice_thickness = get_dicom_thickness(slices)\n    for s in slices:\n        s.SliceThickness = slice_thickness\n    return slices\n\ndef get_pixels_hu(slices):\n    '''\n        Convert a list of DICOM slices into a volume of Hounsfield units (HU);\n        padding pixels outside the scan (stored as -2000) are set to 0 before the rescale slope/intercept is applied\n\n    :param slices:  list of DICOM datasets\n    :return:        int16 volume of one patient, indexed (z, y, x)\n    '''\n    image = np.stack([s.pixel_array for s in slices])\n    image = image.astype(np.int16)\n    image[image == -2000] = 0\n    for slice_number in range(len(slices)):\n        intercept = slices[slice_number].RescaleIntercept\n        slope = slices[slice_number].RescaleSlope\n        if slope != 1:\n            image[slice_number] = slope * image[slice_number].astype(np.float64)\n            image[slice_number] = image[slice_number].astype(np.int16)\n        image[slice_number] += np.int16(intercept)\n    return np.array(image, dtype=np.int16)\n\n\ndef getinfo_dicom(dicom_path):\n    print('dicom_path: ', dicom_path)\n    slices = load_dicom_slices(dicom_path)\n    print(type(slices[0]), slices[0].ImagePositionPatient)\n    print(len(slices), \"\\t\", slices[0].SliceThickness, \"\\t\", slices[0].PixelSpacing)\n    print(\"Orientation: \", slices[0].ImageOrientationPatient)\n    #assert slices[0].ImageOrientationPatient == [1.000000, 0.000000, 0.000000, 0.000000, 1.000000, 0.000000]\n    pixels = get_pixels_hu(slices)\n    image = pixels\n    print(image.shape)\n\n    invert_order = slices[1].ImagePositionPatient[2] > slices[0].ImagePositionPatient[2]\n    print(\"Invert order: \", invert_order, \" - \", slices[1].ImagePositionPatient[2], \",\",\n          slices[0].ImagePositionPatient[2])\n\n    # Copy PixelSpacing into a plain list so that the DICOM dataset itself is not modified\n    pixel_spacing = [float(v) for v in slices[0].PixelSpacing]\n    pixel_spacing.append(float(slices[0].SliceThickness))\n    # save the source DICOM volume size\n    dicom_size = [image.shape[0], image.shape[1], image.shape[2]]\n\n    return pixel_spacing, dicom_size, invert_order\n\ndef extract_dicom_images_patient(dicom_path, target_dir):\n    slices = load_dicom_slices(dicom_path)\n    assert slices[0].ImageOrientationPatient == [1.000000, 0.000000, 0.000000, 0.000000, 1.000000, 0.000000]\n    pixels = get_pixels_hu(slices)\n    image = pixels\n    invert_order = slices[1].ImagePositionPatient[2] > slices[0].ImagePositionPatient[2]\n    # Copy PixelSpacing into a plain list so that the DICOM dataset itself is not modified\n    pixel_spacing = [float(v) for v in slices[0].PixelSpacing]\n    pixel_spacing.append(float(slices[0].SliceThickness))\n    # save the source DICOM volume size\n    dicom_size = [image.shape[0], image.shape[1], image.shape[2]]\n    image = rescale_patient_images(image, pixel_spacing)\n    png_size = [image.shape[0], image.shape[1], image.shape[2]]\n    if not invert_order:\n        image = np.flipud(image)\n    if not os.path.exists(target_dir):\n        os.makedirs(target_dir)\n    else:\n        print(\"png dir already exists, return directly\")\n        return pixel_spacing, dicom_size, png_size, invert_order\n    png_files = glob.glob(os.path.join(target_dir, \"*.png\"))\n    for file in png_files:\n        os.remove(file)\n    for i in tqdm(range(image.shape[0])):\n        img_path = os.path.join(target_dir, \"img_\" + str(i).rjust(4, '0') + \"_i.png\")\n        org_img = image[i]\n        img, mask = get_segmented_lungs(org_img.copy())\n        org_img = normalize_hu_values(org_img)\n        cv2.imwrite(img_path, (org_img * 255).astype(np.uint8))\n        cv2.imwrite(img_path.replace(\"_i.png\", \"_m.png\"), (mask * 255).astype(np.uint8))\n    return pixel_spacing, dicom_size, png_size, invert_order\n\n"
  },
  {
    "path": "util/image_util.py",
    "content": "from typing import Tuple\nimport cv2\nimport os\nimport numpy\nimport glob\nimport random\nimport numpy as np\nfrom scipy import ndimage\n\n\ndef get_normalized_img_unit8(img):\n    '''\n        Min-max normalize an image to the uint8 range [0, 255]\n    '''\n    # numpy.float was removed in NumPy 1.24, so use an explicit dtype\n    img = img.astype(numpy.float64)\n    min_val = img.min()\n    max_val = img.max()\n    img -= min_val\n    if max_val > min_val:  # avoid dividing by zero on a constant image\n        img /= max_val - min_val\n    img *= 255\n    res = img.astype(numpy.uint8)\n    return res\n\n\ndef load_patient_images(png_path, wildcard=\"*.*\", exclude_wildcards=None):\n    '''\n        Load the PNG slices in png_path that match wildcard (minus exclude_wildcards),\n        sorted by file name, as a (z, y, x) uint8 volume\n    '''\n    print(\"png path is\\t\",png_path)\n    src_dir = png_path\n    src_img_paths = glob.glob(os.path.join(src_dir, wildcard))\n    for exclude_wildcard in exclude_wildcards or []:\n        exclude_img_paths = glob.glob(os.path.join(src_dir, exclude_wildcard))\n        src_img_paths = [im for im in src_img_paths if im not in exclude_img_paths]\n    src_img_paths.sort()\n\n    images = [cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) for img_path in src_img_paths]\n    images = [im.reshape((1, ) + im.shape) for im in images]\n    res = numpy.vstack(images)\n    return res\n\n\ndef draw_overlay(png_path: str, p_x: float, p_y: float, p_z: float, index: str,  BOX_size:int = 20) -> None:\n    \"\"\"\n    Draw a box around a point on one slice and save that slice as a PNG\n    Args:\n        png_path: patient directory that contains a png/ sub-directory with the *_i.png slices\n        p_x: X coordinate (as a fraction of the image width)\n        p_y: Y coordinate (as a fraction of the image height)\n        p_z: Z coordinate (as a fraction of the number of slices)\n        index: identifier, used as the output file name\n        BOX_size: half the side length of the box, in pixels\n    \"\"\"\n    patient_img = load_patient_images(png_path + \"/png/\", \"*_i.png\", [])\n    z = int(p_z * patient_img.shape[0])\n    y = int(p_y * patient_img.shape[1])\n    x = int(p_x * patient_img.shape[2])\n    # Bounding-box corners\n    x1 = x - BOX_size\n    y1 = y - BOX_size\n    x2 = x + BOX_size\n    y2 = y + BOX_size\n    target_img = patient_img[z, :, :]\n    cv2.rectangle(target_img, (x1, y1), (x2, y2), (255, 0, 0), 1)\n    cv2.imwrite(png_path + \"/\" + index + \".png\", target_img)\n\ndef prepare_image_for_net3D(img,MEAN_PIXEL_VALUE = 41):\n    '''\n        Normalize an image for the 3D network: subtract the mean pixel value, scale by 1/255\n        and reshape to (1, D, H, W, 1)\n\n    :param img:               image to be normalized\n    :param MEAN_PIXEL_VALUE:  mean pixel value subtracted before scaling\n    :return:                  float32 array of shape (1, D, H, W, 1)\n    '''\n    img = img.astype(numpy.float32)\n    img -= MEAN_PIXEL_VALUE\n    img /= 255.\n    img = img.reshape(1, img.shape[0], img.shape[1], img.shape[2], 1)\n    return img\n\n\ndef move_png2dir(target_dir):\n    '''\n        For every sub-directory of target_dir, move its *.png files into a png/ sub-directory\n    '''\n    import shutil\n    first_dir = []\n    for path in os.listdir(target_dir):\n        if os.path.isdir(os.path.join(target_dir,path)):\n            first_dir.append(os.path.join(target_dir,path))\n    for d in first_dir:\n        png_path = os.path.join(d,'png')\n        if not os.path.exists(png_path):\n            os.mkdir(png_path)\n        for file in os.listdir(d):\n            tmp_file_path = os.path.join(d,file)\n            if tmp_file_path.endswith(\".png\"):\n                shutil.move(tmp_file_path,os.path.join(png_path,file))\n                print(\"move file from %s  to   %s \" %(tmp_file_path,os.path.join(png_path,file)))\n\ndef rescale_patient_images(images_zyx, org_spacing_xyz, target_voxel_mm =1.0, is_mask_image=False, verbose=False):\n    '''\n                Resample a 3D volume to the target voxel size\n\n    :param images_zyx:              source volume, indexed (z, y, x)\n    :param org_spacing_xyz:         original voxel spacing in mm, ordered (x, y, z)\n    :param target_voxel_mm:         target voxel size in mm\n    :param is_mask_image:           use nearest-neighbour instead of linear interpolation (for masks)\n    :param verbose:                 print the shape before and after resampling\n    :return:                        resampled volume, indexed (z, y, x)\n    '''\n    if verbose:\n        print(\"Spacing: \", org_spacing_xyz)\n        print(\"Shape: \", images_zyx.shape)\n\n    # Pass 1: resize the z axis. OpenCV sees the array as (rows, cols, channels) = (z, y, x),\n    # so fy scales z and fx = 1.0 leaves y unchanged\n    resize_x = 1.0\n    resize_y = float(org_spacing_xyz[2]) / float(target_voxel_mm)\n    interpolation = cv2.INTER_NEAREST if is_mask_image else cv2.INTER_LINEAR\n    res = cv2.resize(images_zyx, dsize=None, fx=resize_x, fy=resize_y, interpolation=interpolation)\n    # (z, y, x) -> (y, x, z)\n    res = res.swapaxes(0, 2)\n    res = res.swapaxes(0, 1)\n    # Pass 2: resize x and y, with z now on the channel axis\n    resize_x = float(org_spacing_xyz[0]) / float(target_voxel_mm)\n    resize_y = float(org_spacing_xyz[1]) / float(target_voxel_mm)\n\n    # cv2 can handle at most 512 channels, so resize the z channels in chunks of up to 512\n    if res.shape[2] > 512:\n        chunks = []\n        for start in range(0, res.shape[2], 512):\n            chunk = cv2.resize(numpy.ascontiguousarray(res[:, :, start:start + 512]), dsize=None, fx=resize_x, fy=resize_y, interpolation=interpolation)\n            if chunk.ndim == 2:  # cv2 drops the channel axis of a single-channel chunk\n                chunk = chunk[:, :, numpy.newaxis]\n            chunks.append(chunk)\n        res = numpy.concatenate(chunks, axis=2)\n    else:\n        res = cv2.resize(res, dsize=None, fx=resize_x, fy=resize_y, interpolation=interpolation)\n    # (y, x, z) -> (z, y, x)\n    res = res.swapaxes(0, 2)\n    res = res.swapaxes(2, 1)\n    if verbose:\n        print(\"Shape after: \", res.shape)\n    return res\n\n\ndef rescale_patient_images2(images_zyx, target_shape, verbose=False):\n    '''\n        Resize a 3D (z, y, x) volume to target_shape = (z, y, x) with linear interpolation\n    '''\n    if verbose:\n        print(\"Target: \", target_shape)\n        print(\"Shape: \", images_zyx.shape)\n\n    interpolation = cv2.INTER_LINEAR\n    # Pass 1: resize z. OpenCV sees (z, y, x) as (rows, cols, channels), so dsize = (width, height) = (y, z)\n    res = cv2.resize(images_zyx, dsize=(target_shape[1], target_shape[0]), interpolation=interpolation)\n\n    # (z, y, x) -> (y, x, z)\n    res = res.swapaxes(0, 2)\n    res = res.swapaxes(0, 1)\n\n    # cv2 can handle at most 512 channels, so resize the z channels in chunks of up to 512\n    if res.shape[2] > 512:\n        chunks = []\n        for start in range(0, res.shape[2], 512):\n            chunk = cv2.resize(numpy.ascontiguousarray(res[:, :, start:start + 512]), dsize=(target_shape[2], target_shape[1]), interpolation=interpolation)\n            if chunk.ndim == 2:  # cv2 drops the channel axis of a single-channel chunk\n                chunk = chunk[:, :, numpy.newaxis]\n            chunks.append(chunk)\n        res = numpy.concatenate(chunks, axis=2)\n    else:\n        res = cv2.resize(res, dsize=(target_shape[2], target_shape[1]), interpolation=interpolation)\n\n    # (y, x, z) -> (z, y, x)\n    res = res.swapaxes(0, 2)\n    res = res.swapaxes(2, 1)\n    if verbose:\n        print(\"Shape after: \", res.shape)\n    return res\n\n\ndef resize_image(image: np.ndarray, new_shape: Tuple[int, ...]) -> np.ndarray:\n    \"\"\"\n    Resize an image\n\n    Args:\n        image: input image\n        new_shape: target shape\n\n    Returns:\n        np.ndarray: resized image\n    \"\"\"\n    # Handle single-channel and multi-channel images\n    if len(image.shape) == 3 and len(new_shape) == 2:\n        # 3D stack with a 2D target shape: resize every slice\n        resized_image = np.zeros((image.shape[0], new_shape[0], new_shape[1]))\n        for i in range(image.shape[0]):\n            resized_image[i] = cv2.resize(image[i], (new_shape[1], new_shape[0]))\n        return resized_image\n    elif len(image.shape) == 2 and len(new_shape) == 2:\n        # 2D image (cv2 takes dsize as (width, height))\n        return cv2.resize(image, (new_shape[1], new_shape[0]))\n    else:\n        # Arbitrary number of dimensions: interpolate with scipy.ndimage.zoom\n        resize_factor = tuple(n / o for n, o in zip(new_shape, image.shape))\n        return ndimage.zoom(image, resize_factor, mode='nearest')\n\ndef cv_flip(img,cols,rows,degree):\n    '''\n        Rotate an image by the given angle about its centre (despite its name, this function rotates rather than flips)\n\n    :param img:         image array to be rotated\n    :param cols:        width of the image\n    :param rows:        height of the image\n    :param degree:      rotation angle in degrees (positive values rotate counter-clockwise)\n    :return:            rotated image\n    '''\n    M = cv2.getRotationMatrix2D((cols / 2, rows /2), degree, 1.0)\n    dst = cv2.warpAffine(img, M, (cols, rows))\n    return dst\n\n\ndef random_rotate_img(img, chance, min_angle, max_angle):\n    '''\n        Randomly rotate an image (or a list of images, all by the same angle)\n\n    :param img:         image (or list of images) to be rotated\n    :param chance:      probability of applying the rotation\n    :param min_angle:   minimum rotation angle in degrees\n    :param max_angle:   maximum rotation angle in degrees\n    :return:            rotated image (or list of images)\n    '''\n    if random.random() > chance:\n        return img\n    if not isinstance(img, list):\n        img = [img]\n\n    angle = random.randint(min_angle, max_angle)\n    rows, cols = img[0].shape[:2]\n    # cv2 expects the centre as (x, y) and dsize as (width, height)\n    center = (cols / 2, rows / 2)\n    rot_matrix = cv2.getRotationMatrix2D(center, angle, scale=1.0)\n\n    res = []\n    for img_inst in img:\n        img_inst = cv2.warpAffine(img_inst, rot_matrix, dsize=(cols, rows), borderMode=cv2.BORDER_CONSTANT)\n        res.append(img_inst)\n    if len(res) == 1:\n        res = res[0]\n    return res\n\n\ndef random_flip_img(img, horizontal_chance=0, vertical_chance=0):\n    '''\n        Randomly flip an image horizontally and/or vertically\n\n    :param img:                 image (or list of images) to be flipped\n    :param horizontal_chance:   probability of a horizontal flip\n    :param vertical_chance:     probability of a vertical flip\n    :return:                    flipped image (or list of images)\n    '''\n    flip_horizontal = False\n    if random.random() < horizontal_chance:\n        flip_horizontal = True\n\n    flip_vertical = False\n    if random.random() < vertical_chance:\n        flip_vertical = True\n\n    if not flip_horizontal and not flip_vertical:\n        return img\n\n    flip_val = 1\n    if flip_vertical:\n        flip_val = -1 if flip_horizontal else 0\n\n    if not isinstance(img, list):\n        res = cv2.flip(img, flip_val)  # 0 = X axis, 1 = Y axis,  -1 = both\n    else:\n        res = []\n        for img_item in img:\n            img_flip = cv2.flip(img_item, flip_val)\n            res.append(img_flip)\n    return res\n\n\ndef random_scale_img(img, xy_range, lock_xy=False):\n    '''\n        Randomly scale an image (or a list of images) by factors drawn from xy_range,\n        then pad or centre-crop it back to its original size\n    '''\n    if random.random() > xy_range.chance:\n        return img\n\n    if not isinstance(img, list):\n        img = [img]\n\n    scale_x = random.uniform(xy_range.x_min, xy_range.x_max)\n    scale_y = random.uniform(xy_range.y_min, xy_range.y_max)\n    if lock_xy:\n        scale_y = scale_x\n\n    org_height, org_width = img[0].shape[:2]\n    xy_range.last_x = scale_x\n    xy_range.last_y = scale_y\n\n    res = []\n    for img_inst in img:\n        scaled_width = int(org_width * scale_x)\n        scaled_height = int(org_height * scale_y)\n        scaled_img = cv2.resize(img_inst, (scaled_width, scaled_height), interpolation=cv2.INTER_CUBIC)\n        # Integer division: cv2.copyMakeBorder and array slicing need int offsets\n        if scaled_width < org_width:\n            extend_left = (org_width - scaled_width) // 2\n            extend_right = org_width - extend_left - scaled_width\n            scaled_img = cv2.copyMakeBorder(scaled_img, 0, 0, extend_left, extend_right, borderType=cv2.BORDER_CONSTANT)\n            scaled_width = org_width\n\n        if scaled_height < org_height:\n            extend_top = (org_height - scaled_height) // 2\n            extend_bottom = org_height - extend_top - scaled_height\n            scaled_img = cv2.copyMakeBorder(scaled_img, extend_top, extend_bottom, 0, 0, borderType=cv2.BORDER_CONSTANT)\n            scaled_height = org_height\n\n        start_x = (scaled_width - org_width) // 2\n        start_y = (scaled_height - org_height) // 2\n        tmp = scaled_img[start_y: start_y + org_height, start_x: start_x + org_width]\n        res.append(tmp)\n\n    if len(res) == 1:\n        res = res[0]\n    return res\n\n\nclass XYRange:\n    def __init__(self, x_min, x_max, y_min, y_max, chance=1.0):\n        self.chance = chance\n        self.x_min = x_min\n        self.x_max = x_max\n        self.y_min = y_min\n        self.y_max = y_max\n        self.last_x = 0\n        self.last_y = 0\n\n    def get_last_xy_txt(self):\n        res = \"x_\" + str(int(self.last_x * 100)).replace(\"-\", \"m\") + \"-\" + \"y_\" + str(int(self.last_y * 100)).replace(\n            \"-\", \"m\")\n        return res\n\n\ndef random_translate_img(img, xy_range, border_mode=\"constant\"):\n    '''\n        Randomly translate an image (or a list of images) by integer offsets drawn from xy_range\n    '''\n    if random.random() > xy_range.chance:\n        return img\n    if not isinstance(img, list):\n        img = [img]\n\n    org_height, org_width = img[0].shape[:2]\n    translate_x = random.randint(xy_range.x_min, xy_range.x_max)\n    translate_y = random.randint(xy_range.y_min, xy_range.y_max)\n    trans_matrix = numpy.float32([[1, 0, translate_x], [0, 1, translate_y]])\n\n    border_const = cv2.BORDER_CONSTANT\n    if border_mode == \"reflect\":\n        border_const = cv2.BORDER_REFLECT\n\n    res = []\n    for img_inst in img:\n        img_inst = cv2.warpAffine(img_inst, trans_matrix, (org_width, org_height), borderMode=border_const)\n        res.append(img_inst)\n    if len(res) == 1:\n        res = res[0]\n    xy_range.last_x = translate_x\n    xy_range.last_y = translate_y\n    return res\n\n\ndef data_augmentation(image: np.ndarray, augment_type: str = 'random') -> np.ndarray:\n    \"\"\"\n    Apply data augmentation to an image\n\n    Args:\n        image: input image\n        augment_type: augmentation type, one of 'random', 'flip', 'rotate', 'shift'\n\n    Returns:\n        np.ndarray: augmented image\n    \"\"\"\n    if augment_type == 'random':\n        # Pick one augmentation at random (or none)\n        augment_choices = ['flip', 'rotate', 'shift', 'none']\n        choice = np.random.choice(augment_choices)\n\n        if choice == 'flip':\n            return data_augmentation(image, 'flip')\n        elif choice == 'rotate':\n            return data_augmentation(image, 'rotate')\n        elif choice == 'shift':\n            return data_augmentation(image, 'shift')\n        else:\n            return image\n\n    elif augment_type == 'flip':\n        # Flip along a random axis\n        axis = np.random.randint(0, image.ndim)\n        return np.flip(image, axis=axis)\n\n    elif augment_type == 'rotate':\n        # Random rotation\n        if image.ndim == 2:\n            angle = np.random.randint(0, 360)\n            return ndimage.rotate(image, angle, reshape=False, mode='nearest')\n        else:\n            # 3D: rotate in a randomly chosen plane\n            axes = tuple(np.random.choice(range(image.ndim), size=2, replace=False))\n            angle = np.random.randint(0, 360)\n            return ndimage.rotate(image, angle, axes=axes, reshape=False, mode='nearest')\n\n    elif augment_type == 'shift':\n        # Random shift of -5 to 5 voxels along each axis\n        shift = np.random.randint(-5, 6, size=image.ndim)\n        return ndimage.shift(image, shift, mode='nearest')\n\n    return image"
  },
  {
    "path": "util/mhd_util.py",
    "content": "import os\nimport ntpath\nimport SimpleITK\nimport numpy as np\nimport pandas as pd\nimport cv2\n\nfrom data.dataclass.CTData import CTData\nfrom util.seg_util import normalize_hu_values,get_segmented_lungs\nfrom util import image_util\nfrom constant import tianchi\n\nTARGET_VOXEL_MM = 1.0\nMHD_INFO_HEAD = \"patient_id,shape_0,shape_1,shape_2,origin_x,origin_y,origin_z,direction_z(1_-1),\" \\\n                \"spacing_x,spacing_y,spacing_z,rescale_x,rescale_y,rescale_z\"\n\ndef get_all_mhd_file(BASE_DATA_DIR,base_head,max):\n    \"\"\"\n       Get the list of all .mhd files of the Tianchi dataset, which is split into train_subset00, train_subset01, ..., test_subset00, test_subset01, ...\n\n    :param BASE_DATA_DIR:   root directory that contains the <base_head>_subsetXX folders\n    :param base_head:       'train', 'test' or 'val', used to build train_subset00, test_subset01, val_subset02, ...\n    :param max:             number of subsets to scan (exclusive upper bound), e.g. max=10 scans subset00 to subset09\n    :return:                list of all .mhd file paths\n    \"\"\"\n    mhd_files = []\n    for index in range(0,max):\n        # zero-pad the subset index to two digits, e.g. 3 -> \"03\"\n        sub_path = os.path.join(BASE_DATA_DIR, base_head + \"_subset\" + str(index).zfill(2))\n        for name in os.listdir(sub_path):\n            if name.endswith(\".mhd\"):\n                mhd_files.append(os.path.join(sub_path,name))\n    return mhd_files\n\n\ndef get_luna16_mhd_file(mhd_root):\n    \"\"\"\n       Recursively collect all .mhd files under a LUNA16 root directory (e.g. one containing subset0, subset1, ...)\n\n    :param mhd_root:       root directory of the LUNA16 data\n    :return:               list of all .mhd file paths\n    \"\"\"\n    mhd_files = []\n    for root, _, files in os.walk(mhd_root):\n        for file in files:\n            if file.lower().endswith('.mhd'):\n                # root yielded by os.walk already starts with mhd_root\n                mhd_files.append(os.path.join(root, file))\n    return mhd_files\n\n\ndef read_csv_to_pandas(mhd_info,col_sepator ='\\t'):\n    \"\"\"\n       Read a CSV file of MHD information into a pandas DataFrame indexed by patient id\n\n    :param mhd_info:      CSV file with the MHD information\n    :param col_sepator:   column separator of the data rows (the header row is comma-separated)\n    :return:              pandas DataFrame\n    \"\"\"\n    with open(mhd_info, 'r') as csv:\n        head = csv.readline().rstrip('\\r\\n').split(\",\")  # get the header of the csv\n        indexs = []\n        rows = []\n        for line in csv.readlines():\n            if not line.strip():\n                continue  # skip blank lines\n            values = line.rstrip('\\r\\n').split(col_sepator)\n            rows.append(values)\n            indexs.append(values[0])  # the first element is the patient id\n        df = pd.DataFrame(data=rows, columns=head, index=indexs)\n        return df\n\ndef extract_image_from_mhd(mhd_file_path,png_save_path_root =None):\n    \"\"\"\n        Extract the MHD information and, optionally, save every slice as PNG images\n\n    :param mhd_file_path:       MHD file to extract\n    :param png_save_path_root:  root directory for the extracted images (both the image and the lung-mask PNGs are saved);\n                                if None, only the MHD information is returned and no images are extracted\n    :return:                    list of MHD information fields, in the order of MHD_INFO_HEAD\n    \"\"\"\n    mhd_info = []\n    patient_id = ntpath.basename(mhd_file_path).replace(\".mhd\", \"\")\n    print(\"Patient: \", patient_id)\n    mhd_info.append(patient_id)\n\n    itk_img = SimpleITK.ReadImage(mhd_file_path)\n    img_array = SimpleITK.GetArrayFromImage(itk_img)      # indexed (z, y, x)\n    print(\"Img array: \", img_array.shape)\n    (shape0,shape1,shape2) = img_array.shape\n    mhd_info.append(str(shape2))\n    mhd_info.append(str(shape1))\n    mhd_info.append(str(shape0))\n\n    origin = np.array(itk_img.GetOrigin())      # x,y,z  Origin in world coordinates (mm)\n    print(\"Origin (x,y,z): \", origin)\n    mhd_info.append(str(origin[0]))\n    mhd_info.append(str(origin[1]))\n    mhd_info.append(str(origin[2]))\n\n    direction = np.array(itk_img.GetDirection())      # flattened 3x3 direction-cosine matrix\n    print(\"Direction: \", direction)\n    direct_arow = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]\n    if direction.tolist() == direct_arow:\n        print(\"positive direction..\")\n        mhd_info.append(str(1))\n    else:\n        mhd_info.append(str(-1))\n\n    spacing = np.array(itk_img.GetSpacing())    # spacing of voxels in world coor. (mm)\n    print(\"Spacing (x,y,z): \", spacing)\n    mhd_info.append(str(spacing[0]))\n    mhd_info.append(str(spacing[1]))\n    mhd_info.append(str(spacing[2]))\n\n    rescale = spacing /TARGET_VOXEL_MM\n    print(\"Rescale: \", rescale)\n    mhd_info.append(str(rescale[0]))\n    mhd_info.append(str(rescale[1]))\n    mhd_info.append(str(rescale[2]))\n\n    if png_save_path_root is None:              # get mhd information only\n        return mhd_info\n\n    # Only extract images for patients that have not been extracted yet\n    dst_dir = os.path.join(png_save_path_root, patient_id)\n    if not os.path.exists(dst_dir):\n        os.makedirs(dst_dir)\n        if img_array.shape[1]== 512:\n            img_array = image_util.rescale_patient_images(img_array, spacing, TARGET_VOXEL_MM)\n        for i in range(img_array.shape[0]):\n            img = img_array[i]\n            _, mask = get_segmented_lungs(img.copy())\n            img = normalize_hu_values(img)\n            img_name = \"img_\" + str(i).rjust(4, '0')\n            cv2.imwrite(os.path.join(dst_dir, img_name + \"_i.png\"), (img * 255).astype(np.uint8))\n            cv2.imwrite(os.path.join(dst_dir, img_name + \"_m.png\"), (mask * 255).astype(np.uint8))\n    return mhd_info\n\n"
  },
  {
    "path": "util/progress_watch.py",
    "content": "import datetime\n\n\nclass Stopwatch(object):\n    def start(self):\n        self.start_time = Stopwatch.get_time()\n\n    def get_elapsed_time(self):\n        current_time = Stopwatch.get_time()\n        res = current_time - self.start_time\n        return res\n\n    def get_elapsed_seconds(self):\n        elapsed_time = self.get_elapsed_time()\n        res = elapsed_time.total_seconds()\n        return res\n\n    @staticmethod\n    def get_time():\n        res = datetime.datetime.now()\n        return res\n\n    @staticmethod\n    def start_new():\n        res = Stopwatch()\n        res.start()\n        return res\n"
  },
  {
    "path": "util/seg_util.py",
    "content": "import numpy as np\nimport matplotlib.pyplot as plt\nplt.rcParams['font.family'] = ['SimHei']  # use the SimHei font so that CJK text renders in plots\nplt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly with this font\nfrom scipy import ndimage as ndi\nfrom skimage.filters import roberts\nfrom skimage.measure import regionprops, label\nfrom skimage.morphology import binary_closing, disk, binary_erosion\nfrom skimage.segmentation import clear_border\n\n\ndef normalize_hu_values(image: np.ndarray, min_bound: int = -1000, max_bound: int = 400) -> np.ndarray:\n    \"\"\"\n    Normalize HU values to the [0, 1] range; values outside [min_bound, max_bound] are clipped\n\n    Args:\n        image: input image in HU\n        min_bound: HU value mapped to 0\n        max_bound: HU value mapped to 1\n\n    Returns:\n        np.ndarray: normalized image\n    \"\"\"\n    image = (image - min_bound) / (max_bound - min_bound)\n    image[image > 1] = 1.\n    image[image < 0] = 0.\n    return image\n\ndef get_segmented_lungs(im, plot=False):\n    '''\n        Extract the lung ROI from one 2D CT slice\n\n    :param im:      a 2D slice of a patient's pixel array, in HU (modified in place)\n    :param plot:    whether to plot the lung mask and the masked image\n    :return:        (image with non-lung pixels set to -2000, binary lung mask)\n    '''\n    # Step 1: Convert into a binary image.\n    binary = im < -400\n    # Step 2: Remove the blobs connected to the border of the image.\n    cleared = clear_border(binary)\n    # Step 3: Label the image.\n    label_image = label(cleared)\n    # Step 4: Keep the labels with the 2 largest areas.\n    areas = [r.area for r in regionprops(label_image)]\n    areas.sort()\n    if len(areas) > 2:\n        for region in regionprops(label_image):\n            if region.area < areas[-2]:\n                label_image[region.coords[:, 0], region.coords[:, 1]] = 0\n    binary = label_image > 0\n    # Step 5: Erosion with a disk of radius 2, to separate lung nodules attached to blood vessels.\n    selem = disk(2)\n    binary = binary_erosion(binary, selem)\n    # Step 6: Closing with a disk of radius 10, to keep nodules attached to the lung wall.\n    selem = disk(10)\n    binary = binary_closing(binary, selem)\n    # Step 7: Fill in the small holes inside the binary mask of lungs.\n    edges = roberts(binary)\n    binary = ndi.binary_fill_holes(edges)\n    # Step 8: Superimpose the binary mask on the input image.\n    get_high_vals = binary == 0\n    im[get_high_vals] = -2000\n    if plot:\n        plt.figure(figsize=(10, 10))\n        plt.subplot(1, 2, 1)\n        plt.imshow(binary, cmap='gray')\n        plt.title('Lung Mask')\n        plt.subplot(1, 2, 2)\n        plt.imshow(im, cmap='gray')\n        plt.title('Masked Image')\n        plt.show()\n    return im, binary"
  }
]