', methods=['GET'])
def get_lung_slices_info(session_id):
    """Return slice-index information for a session's lung segmentation.

    Serves ``lung_slices/slices_index.json`` from the session's upload folder
    when it exists. Otherwise a minimal index is built from ``lung_seg.npy``:
    the volume dimensions, the slice count along each axis and one entry per
    z-axis slice (the y/x axis lists are left empty).

    Args:
        session_id: session identifier; must be a key of SESSION_DATA.

    Returns:
        A JSON response ``{"success": True, "slices_info": ...}``, or a
        ``(response, 404)`` tuple when the session or the segmentation data
        does not exist.
    """
    print('/api/lung_slices_info 会话内容\t', SESSION_DATA, session_id)
    # Reject unknown sessions before touching the filesystem
    if session_id not in SESSION_DATA:
        return jsonify({"success": False, "error": "会话不存在"}), 404

    session_dir = os.path.join(UPLOAD_FOLDER, session_id)
    index_path = os.path.join(session_dir, 'lung_slices', 'slices_index.json')
    print('索引文件路径\t', index_path)
    # Prefer the precomputed index when it exists
    if os.path.exists(index_path):
        with open(index_path, 'r') as f:
            slices_info = json.load(f)
        return jsonify({"success": True, "slices_info": slices_info})

    print('索引文件不存在,重新加载 npy 文件\t')
    # Fall back to deriving basic information from the lung segmentation array
    lung_seg_path = os.path.join(session_dir, 'lung_seg.npy')
    if not os.path.exists(lung_seg_path):
        return jsonify({"success": False, "error": "肺部分割数据不存在"}), 404

    # Only the array shape is needed, so memory-map the file instead of
    # reading the whole volume into RAM.
    lung_seg = np.load(lung_seg_path, mmap_mode='r')
    shape = lung_seg.shape
    z_slices = shape[0]
    y_slices = shape[1]
    x_slices = shape[2]

    # Same layout as the index written by save_lung_segmentation_slices
    slices_info = {
        "dimensions": shape,
        "z_slices": z_slices,
        "y_slices": y_slices,
        "x_slices": x_slices,
        "z_axis": [
            {
                "index": z,
                "filename": f"z_slice_{z:04d}.png",
                "path": f"/api/lung_slice/{session_id}/z/{z}",
            }
            for z in range(z_slices)
        ],
        "y_axis": [],
        "x_axis": [],
    }
    return jsonify({"success": True, "slices_info": slices_info})
# Start the server when this module is executed directly
if __name__ == '__main__':
    # Turn on automatic detection
    app.config['AUTO_DETECT'] = True
    # The upload folder must exist before any request writes into it
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
    # NOTE(review): assuming `app` is a Flask app, debug=True enables the
    # interactive debugger; combined with host='0.0.0.0' it is reachable from
    # the network. Do not use this configuration in production.
    server_options = dict(host='0.0.0.0', port=5000, debug=True)
    app.run(**server_options)
================================================
FILE: deploy/backend/data/c3d_nodule_detect.pth
================================================
[File too large to display: 67.5 MB]
================================================
FILE: deploy/backend/dataclass/CTData.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import SimpleITK as sitk
from scipy import ndimage
from enum import Enum
import matplotlib.pyplot as plt
from util.dicom_util import load_dicom_slices, get_pixels_hu, get_dicom_thickness
from util.seg_util import get_segmented_lungs, normalize_hu_values
class CTFormat(Enum):
    """Source file format of a loaded CT volume."""

    DICOM = 1    # folder of DICOM slice files
    MHD = 2      # MetaImage header (.mhd) plus raw volume
    UNKNOWN = 3  # format not recognised
class CTData:
    """
    Unified CT volume container for data loaded from DICOM or MHD files.

    Holds the voxel array (indexed z, y, x), its spatial metadata (origin and
    spacing in x, y, z order, in mm) and, once filter_lung_img_mask() has run,
    the per-slice lung segmentation image and mask.
    """

    def __init__(self):
        # Basic attributes
        self.pixel_data = None      # 3D voxel array, indexed (z, y, x)
        self.lung_seg_img = None    # lung-only image, one normalized slice per z index
        self.lung_seg_mask = None   # lung mask, one slice per z index
        self.origin = None          # coordinate origin (x, y, z), mm
        self.spacing = None         # voxel spacing (x, y, z), mm
        self.orientation = None     # direction matrix (set for MHD input)
        self.z_axis_flip = False    # whether the z axis is flipped
        self.size = None            # array shape (z, y, x)
        self.data_format = None     # CTFormat.DICOM / CTFormat.MHD
        self.metadata = {}          # additional metadata
        self.hu_converted = False   # whether pixel_data holds HU values
        self.preprocessed = False   # whether the data has been preprocessed

    @classmethod
    def from_dicom(cls, dicom_path):
        """
        Load CT data from a folder of DICOM files.

        Args:
            dicom_path: path of the DICOM folder

        Returns:
            CTData object whose pixel_data is already in HU
        """
        ct_data = cls()
        ct_data.data_format = CTFormat.DICOM
        slices = load_dicom_slices(dicom_path)
        ct_data.pixel_data = get_pixels_hu(slices)
        # The z direction needs two slices to compare; a single-slice series
        # keeps the default (not flipped) instead of raising IndexError.
        if len(slices) > 1:
            ct_data.z_axis_flip = slices[1].ImagePositionPatient[2] > slices[0].ImagePositionPatient[2]
        ct_data.hu_converted = True
        slice_thickness = get_dicom_thickness(slices)

        # Pixel spacing (x, y) from the first slice, z from the slice thickness
        try:
            ct_data.spacing = [
                float(slices[0].PixelSpacing[0]),
                float(slices[0].PixelSpacing[1]),
                float(slice_thickness)
            ]
        except Exception:
            print("警告: 无法获取像素间距,使用默认值[1.0, 1.0, 1.0]")
            ct_data.spacing = [1.0, 1.0, 1.0]

        # Origin from the first slice's ImagePositionPatient
        try:
            ct_data.origin = [
                float(slices[0].ImagePositionPatient[0]),
                float(slices[0].ImagePositionPatient[1]),
                float(slices[0].ImagePositionPatient[2])
            ]
        except Exception:
            print("警告: 无法获取坐标原点,使用默认值[0.0, 0.0, 0.0]")
            ct_data.origin = [0.0, 0.0, 0.0]

        ct_data.size = ct_data.pixel_data.shape
        return ct_data

    @classmethod
    def from_mhd(cls, mhd_path):
        """
        Load CT data from an MHD/RAW file pair.

        Args:
            mhd_path: path of the .mhd file

        Returns:
            CTData object

        Raises:
            ValueError: if SimpleITK fails to read the file
        """
        ct_data = cls()
        ct_data.data_format = CTFormat.MHD
        try:
            itk_img = sitk.ReadImage(mhd_path)
            # SimpleITK returns the array in (z, y, x) order
            ct_data.pixel_data = sitk.GetArrayFromImage(itk_img)
            # LUNA16 MHD data is already stored in HU
            ct_data.hu_converted = True
            ct_data.origin = list(itk_img.GetOrigin())    # (x, y, z)
            ct_data.spacing = list(itk_img.GetSpacing())  # (x, y, z)
            ct_data.size = ct_data.pixel_data.shape
            ct_data.orientation = itk_img.GetDirection()
            ct_data.z_axis_flip = False
        except Exception as e:
            raise ValueError(f"加载MHD文件时出错: {e}") from e
        return ct_data

    def convert_to_hu(self):
        """
        Mark the pixel data as HU values (if not already marked).

        Both supported formats are already in HU after loading, so this only
        sets the flag.

        Raises:
            ValueError: if data_format is neither DICOM nor MHD
        """
        if self.hu_converted:
            print("数据已经是HU值格式")
            return
        if self.data_format == CTFormat.DICOM:
            # Conversion already happened in from_dicom
            self.hu_converted = True
        elif self.data_format == CTFormat.MHD:
            # LUNA16 MHD data is already in HU
            self.hu_converted = True
        else:
            raise ValueError("未知数据格式,无法转换为HU值")

    def resample_pixel(self, new_spacing=None):
        """
        Resample the voxel grid to the given spacing with trilinear interpolation.

        Args:
            new_spacing: target voxel spacing [x, y, z] in mm; defaults to [1, 1, 1]

        Returns:
            A new CTData with the resampled pixel_data. Its spacing is a copy of
            the requested new_spacing; because the output shape is rounded to
            whole voxels the effective spacing may differ slightly. The lung
            segmentation is not carried over since it no longer matches the grid.
        """
        # None instead of a list literal: a mutable default would be shared
        # between calls and could be modified through the returned object.
        if new_spacing is None:
            new_spacing = [1, 1, 1]
        if not self.hu_converted:
            self.convert_to_hu()

        # scipy.ndimage works on the (z, y, x) array, so reorder the spacings
        spacing_zyx = [self.spacing[2], self.spacing[1], self.spacing[0]]
        new_spacing_zyx = [new_spacing[2], new_spacing[1], new_spacing[0]]
        resize_factor = np.array(spacing_zyx) / np.array(new_spacing_zyx)
        new_shape = np.round(np.array(self.pixel_data.shape) * resize_factor)
        # Actual zoom factor after rounding the output shape
        real_resize = new_shape / np.array(self.pixel_data.shape)
        resampled_data = ndimage.zoom(self.pixel_data, real_resize, order=1)

        resampled_ct = CTData()
        resampled_ct.pixel_data = resampled_data
        # Copy list-valued metadata so the new object never aliases the
        # caller's list or this object's attributes.
        resampled_ct.spacing = list(new_spacing)
        resampled_ct.origin = None if self.origin is None else list(self.origin)
        resampled_ct.orientation = self.orientation
        resampled_ct.z_axis_flip = self.z_axis_flip
        resampled_ct.size = resampled_data.shape
        resampled_ct.data_format = self.data_format
        resampled_ct.metadata = dict(self.metadata)
        resampled_ct.hu_converted = self.hu_converted
        resampled_ct.preprocessed = self.preprocessed
        return resampled_ct

    def filter_lung_img_mask(self):
        """
        Segment the lungs slice by slice along z.

        Stores the lung-only image (normalized with normalize_hu_values) in
        lung_seg_img and the corresponding masks in lung_seg_mask.
        """
        pixel_data = self.pixel_data.copy()
        seg_img = []
        seg_mask = []
        for index in range(pixel_data.shape[0]):
            one_seg_img, one_seg_mask = get_segmented_lungs(pixel_data[index])
            one_seg_img = normalize_hu_values(one_seg_img)
            seg_img.append(one_seg_img)
            seg_mask.append(one_seg_mask)
        self.lung_seg_img = np.array(seg_img)
        self.lung_seg_mask = np.array(seg_mask)

    def world_to_voxel(self, world_coord):
        """
        Convert world coordinates (mm) to voxel indices.

        Args:
            world_coord: world coordinate [x, y, z] in mm

        Returns:
            integer voxel coordinate [x, y, z] (rounded to the nearest voxel)
        """
        voxel_coord = np.zeros(3, dtype=int)
        for i in range(3):
            voxel_coord[i] = int(round((world_coord[i] - self.origin[i]) / self.spacing[i]))
        return voxel_coord

    def voxel_to_world(self, voxel_coord):
        """
        Convert voxel indices to world coordinates (mm).

        Args:
            voxel_coord: voxel coordinate [x, y, z]

        Returns:
            world coordinate [x, y, z] in mm
        """
        world_coord = np.zeros(3, dtype=float)
        for i in range(3):
            world_coord[i] = voxel_coord[i] * self.spacing[i] + self.origin[i]
        return world_coord

    def extract_cube(self, center_world_mm, size_mm, if_fixed_radius=False):
        """
        Extract a cube from the lung segmentation image around a world-space center.

        Args:
            center_world_mm: cube center in world coordinates [x, y, z] (mm)
            size_mm: cube edge length in mm (used as a voxel edge length when
                if_fixed_radius is True)
            if_fixed_radius: when True, size_mm is treated as a fixed voxel
                size; when False (default) it is converted to voxels per axis
                using the spacing, so each nodule's annotated size is used

        Returns:
            the cube sub-array of lung_seg_img, clipped at the volume borders

        Raises:
            ValueError: if no pixel data is loaded
        """
        if self.pixel_data is None:
            raise ValueError("未加载数据")
        if self.lung_seg_img is None:
            print("肺部区域数据没有分割,现在开始分割..")
            self.filter_lung_img_mask()

        # world_to_voxel returns (x, y, z); the arrays are indexed (z, y, x)
        center_voxel = self.world_to_voxel(center_world_mm)
        center_voxel_zyx = [center_voxel[2], center_voxel[1], center_voxel[0]]

        if if_fixed_radius:
            half_size = [int(size_mm / 2), int(size_mm / 2), int(size_mm / 2)]
        else:
            # Edge length in voxels per axis. Annotated LUNA16 nodule sizes
            # vary, so this yields differently sized cubes.
            size_voxel = [int(size_mm / self.spacing[2]),
                          int(size_mm / self.spacing[1]),
                          int(size_mm / self.spacing[0])]
            half_size = [s // 2 for s in size_voxel]

        z_min = max(0, center_voxel_zyx[0] - half_size[0])
        y_min = max(0, center_voxel_zyx[1] - half_size[1])
        x_min = max(0, center_voxel_zyx[2] - half_size[2])
        z_max = min(self.lung_seg_img.shape[0], center_voxel_zyx[0] + half_size[0])
        y_max = min(self.lung_seg_img.shape[1], center_voxel_zyx[1] + half_size[1])
        x_max = min(self.lung_seg_img.shape[2], center_voxel_zyx[2] + half_size[2])

        cube = self.lung_seg_img[z_min:z_max, y_min:y_max, x_min:x_max]
        return cube

    def visualize_slice(self, slice_idx=None, axis=0, show_lung_only=False):
        """
        Display a single slice with matplotlib.

        Args:
            slice_idx: slice index; the center slice when None
            axis: slicing axis (0=z, 1=y, 2=x)
            show_lung_only: show the lung segmentation image (segmenting first
                if needed) instead of the raw pixel data

        Raises:
            ValueError: if no pixel data is loaded
        """
        if self.pixel_data is None:
            raise ValueError("未加载数据")
        # The lung image is produced lazily, as in extract_cube; previously
        # this path failed with a TypeError on None.
        if show_lung_only and self.lung_seg_img is None:
            print("肺部区域数据没有分割,现在开始分割..")
            self.filter_lung_img_mask()

        if slice_idx is None:
            slice_idx = self.pixel_data.shape[axis] // 2

        volume = self.lung_seg_img if show_lung_only else self.pixel_data
        if axis == 0:    # z
            slice_data = volume[slice_idx, :, :]
        elif axis == 1:  # y
            slice_data = volume[:, slice_idx, :]
        else:            # x
            slice_data = volume[:, :, slice_idx]

        plt.figure(figsize=(10, 8))
        plt.imshow(slice_data, cmap='gray')
        axis_name = ['z', 'y', 'x'][axis]
        title = f"切片 {slice_idx} (沿{axis_name}轴)"
        plt.title(title)
        plt.colorbar(label='像素值')
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    def visualize_nodule(self, coord_x, coord_y, coord_z, diameter):
        """
        Show axial, coronal and sagittal center slices of a cube around a nodule.

        Args:
            coord_x: nodule center x in world coordinates (mm)
            coord_y: nodule center y in world coordinates (mm)
            coord_z: nodule center z in world coordinates (mm)
            diameter: nodule diameter (mm)
        """
        # Cube edge is at least 32 and 1.5x the diameter so the nodule fits
        cube_size = max(32, int(diameter * 1.5))
        cube = self.extract_cube([coord_x, coord_y, coord_z], cube_size)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        center_z = cube.shape[0] // 2
        center_y = cube.shape[1] // 2
        center_x = cube.shape[2] // 2

        axes[0].imshow(cube[center_z, :, :], cmap='gray')
        axes[0].set_title(f'轴向视图 (z={center_z})')
        axes[0].axis('off')
        axes[1].imshow(cube[:, center_y, :], cmap='gray')
        axes[1].set_title(f'冠状位视图 (y={center_y})')
        axes[1].axis('off')
        axes[2].imshow(cube[:, :, center_x], cmap='gray')
        axes[2].set_title(f'矢状位视图 (x={center_x})')
        axes[2].axis('off')

        fig.suptitle(f"结节- 位置: ({coord_x:.1f}, {coord_y:.1f}, {coord_z:.1f})mm, " +
                     f"直径: {diameter:.1f}mm,", fontsize=14)
        plt.tight_layout()
        plt.show()

    def save_as_nifti(self, output_path):
        """
        Save the CT volume in NIfTI format via SimpleITK.

        Args:
            output_path: output file path

        Raises:
            ValueError: if no pixel data is loaded
        """
        if self.pixel_data is None:
            raise ValueError("未加载数据")
        # GetImageFromArray expects the (z, y, x) array order used here
        img = sitk.GetImageFromArray(self.pixel_data)
        # SimpleITK expects real-valued vectors; convert explicitly so that
        # ints or numpy scalars in origin/spacing are accepted.
        img.SetOrigin(tuple(float(v) for v in self.origin))
        img.SetSpacing(tuple(float(v) for v in self.spacing))
        if self.orientation is not None:
            img.SetDirection(self.orientation)
        sitk.WriteImage(img, output_path)
        print(f"已保存为NIfTI格式: {output_path}")
================================================
FILE: deploy/backend/dataclass/NoduleCube.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os,torch
import numpy as np
import cv2
from typing import Optional
from dataclasses import dataclass
import matplotlib.pyplot as plt
from scipy import ndimage
def normal_cube_to_tensor(cube_data):
    """
    Min-max normalize a cube to [0, 1] and wrap it as a PyTorch tensor.

    Used by both training and inference.

    :param cube_data: 3D ndarray (typically shaped [32, 32, 32])
    :return: float32 tensor with batch and channel axes prepended,
             e.g. shape (1, 1, 32, 32, 32)
    """
    volume = cube_data.astype(np.float32)
    lo = volume.min()
    hi = volume.max()
    span = hi - lo

    # A (nearly) constant cube has no usable range: return all zeros rather
    # than dividing by ~0.
    if span < 1e-10:
        scaled = np.zeros_like(volume)
    else:
        scaled = (volume - lo) / span

    # Replace any NaN / +-inf left over (e.g. from NaN input) with finite values
    if not np.isfinite(scaled).all():
        scaled = np.nan_to_num(scaled, nan=0.0, posinf=1.0, neginf=0.0)

    return torch.from_numpy(scaled).float()[None, None]
@dataclass
class NoduleCube:
    """
    3D cube of voxel data around a lung nodule.

    Independent of the source CT volume: it only stores and transforms an
    already extracted cube, and can round-trip it through .npy files or a
    PNG image that tiles the slices in a grid.
    """
    # Basic attributes
    cube_size: int = 64  # cube edge length (default 64x64x64)
    pixel_data: Optional[np.ndarray] = None  # voxel data, shape (cube_size, cube_size, cube_size)
    # Nodule features
    center_x: int = 0  # nodule center x
    center_y: int = 0  # nodule center y
    center_z: int = 0  # nodule center z
    radius: float = 0.0  # nodule radius
    malignancy: int = 0  # 0 = benign, 1 = malignant
    # File paths
    npy_path: str = ""  # .npy file path
    png_path: str = ""  # .png file path

    def __post_init__(self):
        """Load pixel data from npy_path (preferred) or png_path when none was given."""
        if self.npy_path and self.pixel_data is None:
            self.load_from_npy()
        elif self.png_path and self.pixel_data is None:
            self.load_from_png()

    @staticmethod
    def _grid_layout(cube_size):
        """Return the (rows, cols) grid used to tile cube_size slices in a PNG."""
        if cube_size == 64:
            return 8, 8
        # Otherwise pick the largest divisor not exceeding sqrt(cube_size),
        # keeping the grid as close to square as possible.
        rows = int(np.sqrt(cube_size))
        while cube_size % rows != 0:
            rows -= 1
        return rows, cube_size // rows

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of path if it has one (a bare filename needs none)."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def load_from_npy(self) -> None:
        """
        Load the cube from npy_path, resizing it to cube_size if needed.

        Raises:
            FileNotFoundError: if npy_path does not exist
            ValueError: if the file cannot be loaded or is not a 3D array
        """
        if not os.path.exists(self.npy_path):
            raise FileNotFoundError(f"文件不存在: {self.npy_path}")
        try:
            self.pixel_data = np.load(self.npy_path)
            if len(self.pixel_data.shape) != 3:
                raise ValueError(f"像素数据必须是3D数组,当前形状: {self.pixel_data.shape}")
            if self.pixel_data.shape != (self.cube_size,) * 3:
                self.resize(self.cube_size)
        except Exception as e:
            raise ValueError(f"加载NPY文件时出错: {e}") from e

    def save_to_npy(self, output_path: str) -> str:
        """
        Save the cube as a .npy file.

        Args:
            output_path: output path; missing parent directories are created

        Returns:
            the path written (also stored in npy_path)
        """
        if self.pixel_data is None:
            raise ValueError("没有像素数据可保存")
        self._ensure_parent_dir(output_path)
        np.save(output_path, self.pixel_data)
        self.npy_path = output_path
        return output_path

    def save_to_png(self, output_path: str) -> str:
        """
        Save the cube as one PNG with the slices tiled in a grid
        (8x8 for a 64-slice cube).

        Data whose maximum is <= 1.0 is treated as [0, 1] and stretched to
        [0, 255]; the scaling is decided once for the whole cube so every tile
        uses the same mapping, and values are clipped to [0, 255] before the
        uint8 cast so they saturate instead of wrapping around.

        Args:
            output_path: output PNG path; missing parent directories are created

        Returns:
            the path written (also stored in png_path)
        """
        if self.pixel_data is None:
            raise ValueError("没有像素数据可保存")
        self._ensure_parent_dir(output_path)

        size = self.cube_size
        rows, cols = self._grid_layout(size)
        combined_img = np.zeros((rows * size, cols * size), dtype=np.uint8)

        data = self.pixel_data
        if data.max() <= 1.0:
            data = data * 255
        data = np.clip(data, 0, 255).astype(np.uint8)

        for i in range(size):
            row, col = divmod(i, cols)
            y_start = row * size
            x_start = col * size
            combined_img[y_start:y_start + size, x_start:x_start + size] = data[i]

        cv2.imwrite(output_path, combined_img)
        self.png_path = output_path
        return output_path

    def load_from_png(self) -> None:
        """
        Load the cube from a tiled PNG written by save_to_png.

        Pixel values are scaled to [0, 1] as float32.

        Raises:
            FileNotFoundError: if png_path does not exist
            ValueError: if the image cannot be read or has unexpected dimensions
        """
        if not os.path.exists(self.png_path):
            raise FileNotFoundError(f"文件不存在: {self.png_path}")
        try:
            img = cv2.imread(self.png_path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals failure by returning None rather than raising
            if img is None:
                raise ValueError(f"无法读取PNG图像: {self.png_path}")

            size = self.cube_size
            rows, cols = self._grid_layout(size)
            expected_height = rows * size
            expected_width = cols * size
            if img.shape[0] != expected_height or img.shape[1] != expected_width:
                raise ValueError(f"图像尺寸不匹配: 期望{expected_height}x{expected_width}, 实际{img.shape[0]}x{img.shape[1]}")

            cube_data = np.zeros((size, size, size), dtype=np.float32)
            for i in range(size):
                row, col = divmod(i, cols)
                y_start = row * size
                x_start = col * size
                tile = img[y_start:y_start + size, x_start:x_start + size]
                cube_data[i] = tile.astype(np.float32) / 255.0
            self.pixel_data = cube_data
        except Exception as e:
            raise ValueError(f"加载PNG文件时出错: {e}") from e

    def set_cube_data(self, pixel_data: np.ndarray) -> None:
        """
        Set the voxel data, resizing it to cube_size when the shape differs.

        Args:
            pixel_data: 3D voxel array

        Raises:
            ValueError: if pixel_data is not 3D
        """
        if len(pixel_data.shape) != 3:
            raise ValueError(f"像素数据必须是3D数组,当前形状: {pixel_data.shape}")
        self.pixel_data = pixel_data
        if self.pixel_data.shape != (self.cube_size,) * 3:
            self.resize(self.cube_size)

    def resize(self, new_size: int) -> None:
        """
        Resample the cube to new_size on every axis and update cube_size.

        Args:
            new_size: new edge length
        """
        if self.pixel_data is None:
            raise ValueError("没有像素数据可调整大小")
        zoom_factors = [new_size / self.pixel_data.shape[0],
                        new_size / self.pixel_data.shape[1],
                        new_size / self.pixel_data.shape[2]]
        self.pixel_data = ndimage.zoom(self.pixel_data, zoom_factors, mode='nearest')
        self.cube_size = new_size

    def augment(self, rotation: bool = True, flip_axis: int = -1, noise: bool = True) -> 'NoduleCube':
        """
        Return an augmented copy of this cube.

        Args:
            rotation: rotate by a random angle in [-20, 20] degrees in each of the three axis planes
            flip_axis: axis to flip along; -1 (default) disables flipping
            noise: add Gaussian noise (std drawn from [0, 0.05]) and clip the result to [0, 1]

        Returns:
            a new NoduleCube with the same nodule metadata
        """
        if self.pixel_data is None:
            raise ValueError("没有像素数据可增强")
        augmented_cube = self.pixel_data.copy()

        if rotation:
            angles = np.random.uniform(-20, 20, 3)
            augmented_cube = ndimage.rotate(augmented_cube, angles[0], axes=(1, 2), reshape=False, mode='nearest')
            augmented_cube = ndimage.rotate(augmented_cube, angles[1], axes=(0, 2), reshape=False, mode='nearest')
            augmented_cube = ndimage.rotate(augmented_cube, angles[2], axes=(0, 1), reshape=False, mode='nearest')

        if flip_axis >= 0:
            augmented_cube = np.flip(augmented_cube, axis=flip_axis)

        if noise:
            noise_level = np.random.uniform(0.0, 0.05)
            noise_array = np.random.normal(0, noise_level, augmented_cube.shape)
            augmented_cube = augmented_cube + noise_array
            # Keep values inside [0, 1] after adding noise
            augmented_cube = np.clip(augmented_cube, 0, 1)

        new_cube = NoduleCube(
            cube_size=self.cube_size,
            center_x=self.center_x,
            center_y=self.center_y,
            center_z=self.center_z,
            radius=self.radius,
            malignancy=self.malignancy
        )
        new_cube.set_cube_data(augmented_cube)
        return new_cube

    def visualize_3d(self, output_path: Optional[str] = None, show: bool = True) -> None:
        """
        Plot the three orthogonal center slices and maximum intensity projections.

        The array is indexed (z, y, x): a fixed-y slice / projection along y is
        the coronal view and a fixed-x one is the sagittal view.

        Args:
            output_path: optional path to save the figure to
            show: display the figure; when False the figure is closed
        """
        if self.pixel_data is None:
            raise ValueError("没有像素数据可视化")

        fig, axes = plt.subplots(2, 3, figsize=(12, 8))
        center_z = self.pixel_data.shape[0] // 2
        center_y = self.pixel_data.shape[1] // 2
        center_x = self.pixel_data.shape[2] // 2

        slice_xy = self.pixel_data[center_z, :, :]
        slice_xz = self.pixel_data[:, center_y, :]
        slice_yz = self.pixel_data[:, :, center_x]

        axes[0, 0].imshow(slice_xy, cmap='gray')
        axes[0, 0].set_title(f'轴向视图 (Z={center_z})')
        axes[0, 1].imshow(slice_xz, cmap='gray')
        axes[0, 1].set_title(f'冠状位视图 (Y={center_y})')
        axes[0, 2].imshow(slice_yz, cmap='gray')
        axes[0, 2].set_title(f'矢状位视图 (X={center_x})')

        # Maximum intensity projections along z, y and x
        mip_xy = np.max(self.pixel_data, axis=0)
        mip_xz = np.max(self.pixel_data, axis=1)
        mip_yz = np.max(self.pixel_data, axis=2)
        axes[1, 0].imshow(mip_xy, cmap='gray')
        axes[1, 0].set_title('最大强度投影 (轴向)')
        axes[1, 1].imshow(mip_xz, cmap='gray')
        axes[1, 1].set_title('最大强度投影 (冠状位)')
        axes[1, 2].imshow(mip_yz, cmap='gray')
        axes[1, 2].set_title('最大强度投影 (矢状位)')

        nodule_info = f"结节中心: ({self.center_x}, {self.center_y}, {self.center_z})\n"
        nodule_info += f"半径: {self.radius:.1f}\n"
        nodule_info += f"恶性度: {'恶性' if self.malignancy == 1 else '良性'}"
        fig.suptitle(nodule_info, fontsize=12)
        plt.tight_layout()

        if output_path:
            plt.savefig(output_path, dpi=200, bbox_inches='tight')
        if show:
            plt.show()
        else:
            plt.close(fig)

    @classmethod
    def from_npy(cls, file_path: str, cube_size: int = 64) -> 'NoduleCube':
        """
        Create a cube from a .npy file.

        Args:
            file_path: .npy file path
            cube_size: cube edge length the data is resized to

        Returns:
            NoduleCube instance
        """
        cube = cls(cube_size=cube_size, npy_path=file_path)
        # __post_init__ already loads a non-empty npy_path; only load here if
        # that did not happen (an empty path then raises FileNotFoundError).
        if cube.pixel_data is None:
            cube.load_from_npy()
        return cube

    @classmethod
    def from_png(cls, file_path: str, cube_size: int = 64) -> 'NoduleCube':
        """
        Create a cube from a tiled PNG file.

        Args:
            file_path: PNG file path
            cube_size: cube edge length

        Returns:
            NoduleCube instance
        """
        cube = cls(cube_size=cube_size, png_path=file_path)
        # __post_init__ already loads a non-empty png_path
        if cube.pixel_data is None:
            cube.load_from_png()
        return cube

    @classmethod
    def from_array(cls,
                   pixel_data: np.ndarray,
                   center_x: int = 0,
                   center_y: int = 0,
                   center_z: int = 0,
                   radius: float = 0.0,
                   malignancy: int = 0) -> 'NoduleCube':
        """
        Create a cube from a numpy array.

        Args:
            pixel_data: 3D voxel array with equal edge lengths
            center_x: center x coordinate
            center_y: center y coordinate
            center_z: center z coordinate
            radius: nodule radius
            malignancy: 0 = benign, 1 = malignant

        Returns:
            NoduleCube instance whose cube_size is the array's edge length

        Raises:
            ValueError: if the array is not 3D or not cube-shaped
        """
        if len(pixel_data.shape) != 3:
            raise ValueError(f"像素数据必须是3D数组,当前形状: {pixel_data.shape}")
        cube_size = pixel_data.shape[0]
        if pixel_data.shape[1] != cube_size or pixel_data.shape[2] != cube_size:
            raise ValueError(f"像素数据必须是立方体形状,当前形状: {pixel_data.shape}")
        cube = cls(
            cube_size=cube_size,
            center_x=center_x,
            center_y=center_y,
            center_z=center_z,
            radius=radius,
            malignancy=malignancy
        )
        cube.set_cube_data(pixel_data)
        return cube
================================================
FILE: deploy/backend/dataclass/__init__.py
================================================
================================================
FILE: deploy/backend/detector.py
================================================
import os
import sys
import time
import logging
import numpy as np
import pandas as pd
import torch
from threading import Thread, Lock
import json
# 添加项目根目录到系统路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
# 导入肺结节检测模块
from inference.pytorch_nodule_detector import (
load_ct_data,
load_model,
get_lung_bounds,
scan_ct_data,
reduce_overlapping_nodules,
filter_false_positives,
format_results
)
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Per-session detection state and the per-session locks guarding it,
# both keyed by session_id
session_states, session_locks = {}, {}

# Status values stored under session_states[session_id]["status"]
STATUS_LOADING = "loading"              # loading CT data
STATUS_PREPROCESSING = "preprocessing"  # lung segmentation / preprocessing
STATUS_SCANNING = "scanning"            # scanning the lung region for nodules
STATUS_FILTERING = "filtering"          # merging overlaps / false-positive filtering
STATUS_COMPLETED = "completed"          # finished successfully
STATUS_ERROR = "error"                  # failed
class NoduleDetector:
    """Lung nodule detector wrapping the functions of pytorch_nodule_detector.py.

    Each detect() call runs the pipeline in a background daemon thread.
    Progress, results and errors are written to the module-level
    ``session_states`` dict under the session id, guarded by that session's
    lock in ``session_locks``.
    """

    def __init__(self, model_path=None, device='cuda'):
        """Initialize the detector.

        Args:
            model_path: model weights path; loaded right away when the file exists
            device: requested device; 'cuda' is used only when CUDA is
                available, any other value selects the CPU
        """
        self.model_path = model_path
        self.device = torch.device(device if torch.cuda.is_available() and device == 'cuda' else 'cpu')
        self.model = None
        # Callback(session_id, results) invoked after a successful detection.
        # Set before loading so the attribute always exists.
        self.completion_callback = None
        print("给出的模型路径是\t", model_path)
        if model_path and os.path.exists(model_path):
            # Report success only once the weights were actually loaded
            if self.load_model(model_path):
                print("模型已经完成加载!")

    def load_model(self, model_path):
        """Load model weights.

        Args:
            model_path: path of the weights file

        Returns:
            True on success, False when the file is missing or loading fails
            (the error is logged and any previously loaded model is kept)
        """
        try:
            if not os.path.exists(model_path):
                logger.error(f"模型文件不存在: {model_path}")
                return False
            # Calls the module-level load_model imported from
            # pytorch_nodule_detector, not this method.
            self.model, self.device = load_model(model_path, self.device)
            self.model_path = model_path
            return True
        except Exception as e:
            logger.error(f"加载模型时出错: {e}")
            return False

    def detect(self, file_path, session_id, patient_id=None):
        """Start lung nodule detection in a background thread.

        Args:
            file_path: CT file or folder path
            session_id: session ID; an existing session state is overwritten
            patient_id: patient ID

        Returns:
            True if the detection thread was started, False if no model is loaded
        """
        # Explicit None check: a loaded nn.Module is "not loaded" only when
        # absent, independent of any truthiness it might define.
        if self.model is None:
            logger.error("模型未加载")
            return False
        if session_id in session_states:
            logger.warning(f"会话 {session_id} 已存在,将被覆盖")
        print("进入detect 中的 file_path\t", file_path)

        session_states[session_id] = {
            "status": STATUS_LOADING,
            "progress": 0,
            "message": "正在加载CT数据...",
            "started_at": time.time(),
            "patient_id": patient_id,
            "file_path": file_path,
            "ct_data": None,
            "nodules": None,
            "lung_bounds": None,
            "error": None
        }
        if session_id not in session_locks:
            session_locks[session_id] = Lock()

        thread = Thread(target=self._detect_thread, args=(file_path, session_id, patient_id))
        thread.daemon = True
        thread.start()
        return True

    def _detect_thread(self, file_path, session_id, patient_id):
        """Run the full detection pipeline for one session (thread target).

        Steps: load CT data, compute lung bounds, scan, merge overlapping
        nodules, filter false positives, format results, extract cubes. Any
        exception is logged and recorded in the session with STATUS_ERROR.

        Args:
            file_path: CT file or folder path
            session_id: session ID
            patient_id: patient ID (the session ID is used when it is falsy)
        """
        try:
            self._update_session(session_id, {
                "status": STATUS_LOADING,
                "progress": 0,
                "message": "正在加载CT数据..."
            })
            ct_data = load_ct_data(file_path)
            if ct_data is None:
                raise ValueError("CT数据加载失败")
            self._update_session(session_id, {
                "ct_data": ct_data,
                "progress": 10,
                "message": "CT数据加载完成,开始进行肺部分割..."
            })

            self._update_session(session_id, {
                "status": STATUS_PREPROCESSING,
                "progress": 20,
                "message": "正在进行肺部分割..."
            })
            # Lung segmentation already ran inside load_ct_data; only the
            # bounds are computed here.
            lung_bounds = get_lung_bounds(ct_data.lung_seg_mask)
            if lung_bounds is None:
                raise ValueError("未能找到有效的肺部区域")
            self._update_session(session_id, {
                "lung_bounds": lung_bounds,
                "progress": 30,
                "message": "肺部分割完成,开始检测结节..."
            })

            self._update_session(session_id, {
                "status": STATUS_SCANNING,
                "progress": 40,
                "message": "正在扫描肺部区域,检测结节..."
            })
            # The progress logger maps scan progress messages onto 40%-80%
            progress_logger = self._create_progress_logger(session_id)
            results_df = scan_ct_data(ct_data, self.model, self.device, progress_logger)
            self._update_session(session_id, {
                "progress": 80,
                "message": f"扫描完成,初步检测到 {len(results_df)} 个可能的结节,进行假阳性过滤..."
            })

            self._update_session(session_id, {
                "status": STATUS_FILTERING,
                "progress": 85,
                "message": "正在合并重叠结节..."
            })
            reduced_df = reduce_overlapping_nodules(results_df)

            self._update_session(session_id, {
                "progress": 90,
                "message": f"初步合并后剩余 {len(reduced_df)} 个结节,正在进行假阳性过滤..."
            })
            filtered_df = filter_false_positives(reduced_df, ct_data)

            self._update_session(session_id, {
                "progress": 95,
                "message": f"过滤完成,最终检测到 {len(filtered_df)} 个结节,正在生成结果..."
            })
            final_df = format_results(filtered_df, ct_data, patient_id or session_id)
            nodules = self._extract_nodule_cubes(ct_data, final_df)

            final_results = {
                "status": STATUS_COMPLETED,
                "progress": 100,
                "message": f"检测完成,共发现 {len(nodules)} 个结节",
                "nodules": nodules,
                "completed_at": time.time()
            }
            self._update_session(session_id, final_results)

            # A failing callback is logged but does not mark the detection as failed
            if self.completion_callback:
                try:
                    self.completion_callback(session_id, {"nodules": nodules})
                except Exception as callback_error:
                    logger.error(f"执行完成回调函数时出错: {str(callback_error)}", exc_info=True)
        except Exception as e:
            logger.error(f"检测过程出错: {str(e)}", exc_info=True)
            self._update_session(session_id, {
                "status": STATUS_ERROR,
                "message": f"检测失败: {str(e)}",
                "error": str(e)
            })

    def _create_progress_logger(self, session_id):
        """Create a logger-like object that turns scan messages into session progress.

        Args:
            session_id: session ID

        Returns:
            object with info/debug/warning/error methods
        """
        class ProgressLogger:
            def __init__(self, outer_self, session_id):
                self.outer_self = outer_self
                self.session_id = session_id
                self.last_progress_time = time.time()

            def info(self, message):
                # Messages of the form "处理进度: current/total ..." map the
                # scan fraction onto 40%-80% of the overall progress.
                if "处理进度:" in message:
                    try:
                        progress_part = message.split("处理进度:")[1].split("%")[0].strip()
                        progress_parts = progress_part.split('/')
                        if len(progress_parts) == 2:
                            current, total = map(int, progress_parts)
                            progress = min(40 + int(current / total * 40), 80)
                            # Throttle session updates to at most one per second
                            current_time = time.time()
                            if current_time - self.last_progress_time > 1:
                                self.last_progress_time = current_time
                                self.outer_self._update_session(self.session_id, {
                                    "progress": progress,
                                    "message": message
                                })
                    except (ValueError, IndexError, ZeroDivisionError):
                        # Progress reporting is best-effort: unparsable
                        # messages are ignored.
                        pass
                if "扫描完成" in message:
                    self.outer_self._update_session(self.session_id, {
                        "progress": 80,
                        "message": message
                    })

            # NOTE(review): scan_ct_data is not visible here; these forwarders
            # keep other logger calls it might make from raising AttributeError.
            def debug(self, message, *args, **kwargs):
                logger.debug(message, *args, **kwargs)

            def warning(self, message, *args, **kwargs):
                logger.warning(message, *args, **kwargs)

            def error(self, message, *args, **kwargs):
                logger.error(message, *args, **kwargs)

        return ProgressLogger(self, session_id)

    @staticmethod
    def _crop_centered(volume, center_zyx, cube_size):
        """Return a cube_size**3 crop of volume centred on center_zyx, zero-padded outside the volume."""
        half_size = cube_size // 2
        cube = np.zeros((cube_size, cube_size, cube_size), dtype=volume.dtype)
        src = []
        dst = []
        for center, dim in zip(center_zyx, volume.shape):
            start = center - half_size  # may be negative near the border
            lo = max(0, start)
            hi = min(dim, start + cube_size)
            if hi <= lo:
                # The crop lies entirely outside the volume along this axis
                return cube
            src.append(slice(lo, hi))
            # Offset the copy so the requested center stays at index half_size
            dst.append(slice(lo - start, hi - start))
        cube[tuple(dst)] = volume[tuple(src)]
        return cube

    def _extract_nodule_cubes(self, ct_data, nodules_df, cube_size=32):
        """Extract a cube around each detected nodule from the lung segmentation image.

        Crops near the volume border are zero-padded so that the nodule stays
        at the cube center.

        Args:
            ct_data: CTData object (lung_seg_img is read)
            nodules_df: nodule DataFrame with voxel_x/y/z, nodule_id,
                world_x/y/z, diameter_mm and prob columns
            cube_size: cube edge length in voxels

        Returns:
            list of dicts with the cube (as nested lists for JSON) and nodule
            information; rows that fail are logged and skipped
        """
        nodule_list = []
        volume = ct_data.lung_seg_img
        for _, row in nodules_df.iterrows():
            try:
                x = int(row['voxel_x'])
                y = int(row['voxel_y'])
                z = int(row['voxel_z'])
                cube = self._crop_centered(volume, (z, y, x), cube_size)
                nodule_info = {
                    'id': int(row['nodule_id']),
                    'cube': cube.tolist(),  # lists are JSON-serializable
                    'voxel_coords': [x, y, z],
                    'world_coords': [float(row['world_x']), float(row['world_y']), float(row['world_z'])],
                    'diameter_mm': float(row['diameter_mm']),
                    'probability': float(row['prob'])
                }
                nodule_list.append(nodule_info)
            except Exception as e:
                logger.error(f"提取结节立方体时出错: {str(e)}", exc_info=True)
        return nodule_list

    def _update_session(self, session_id, updates):
        """Merge updates into a session's state under its lock.

        Args:
            session_id: session ID; unknown sessions are logged and ignored
            updates: dict of fields to set
        """
        if session_id not in session_states:
            logger.warning(f"会话 {session_id} 不存在")
            return
        with session_locks[session_id]:
            for key, value in updates.items():
                session_states[session_id][key] = value

    def get_session_state(self, session_id):
        """Return a serializable snapshot of a session's state.

        The CTData object is dropped and lung_bounds is reduced to z_min,
        z_max and the number of regions.

        Args:
            session_id: session ID

        Returns:
            state dict copy, or None for an unknown session
        """
        if session_id not in session_states:
            return None
        with session_locks[session_id]:
            # Shallow copy so the live state is not modified
            state = session_states[session_id].copy()
        if 'ct_data' in state:
            del state['ct_data']
        if 'lung_bounds' in state:
            bounds_info = {}
            if state['lung_bounds']:
                bounds_info = {
                    'z_min': state['lung_bounds']['z_min'],
                    'z_max': state['lung_bounds']['z_max'],
                    'region_count': len(state['lung_bounds']['regions'])
                }
            state['lung_bounds'] = bounds_info
        return state

    def set_completion_callback(self, callback):
        """Set the callback run after a detection completes.

        Args:
            callback: callable taking (session_id, results)
        """
        self.completion_callback = callback
def get_detector_instance(model_path=None):
    """Return the shared NoduleDetector, creating it on the first call.

    Args:
        model_path: optional model weights path. The first call passes it to
            the NoduleDetector constructor; on later calls a non-empty path
            that differs from the current one is loaded into the existing
            instance.

    Returns:
        the shared NoduleDetector instance
    """
    if hasattr(get_detector_instance, 'instance'):
        detector = get_detector_instance.instance
        if model_path and detector.model_path != model_path:
            detector.load_model(model_path)
        return detector
    get_detector_instance.instance = NoduleDetector(model_path)
    return get_detector_instance.instance
================================================
FILE: deploy/backend/models/pytorch_c3d_tiny.py
================================================
import torch.nn as nn
import torchvision.transforms as transforms
# my_tranform =transforms.Compose([
# # transforms.Resize((32,32,32)),
# transforms.ToTensor(),
# transforms.Normalize((0.5,0.5,0.5), (0.5, 0.5,0.5))
# ])
class C3dTiny(nn.Module):
    """Small 3D CNN producing 2-class logits for a single-channel volume.

    The first fully connected layer is sized for a 32x32x32 input, i.e. an
    input tensor of shape (N, 1, 32, 32, 32); forward returns shape (N, 2).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 64 channels; BatchNorm was added on top of the original
        # network. Pooling is in-plane only: 32x32x32 -> 32x16x16.
        self.conv_block1 = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
        )
        # Stage 2: 64 -> 128 channels (BatchNorm also added); 32x16x16 -> 16x8x8
        self.conv_block2 = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
        )
        self.drop_out1 = nn.Dropout(0.2)
        # Stage 3: two 256-channel convolutions; 16x8x8 -> 8x4x4
        self.conv_block3 = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.Conv3d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
        )
        self.drop_out2 = nn.Dropout(0.2)
        # Stage 4: two 512-channel convolutions; 8x4x4 -> 4x2x2
        self.conv_block4 = nn.Sequential(
            nn.Conv3d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm3d(512),
            nn.ReLU(),
            nn.Conv3d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm3d(512),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
        )
        self.drop_out3 = nn.Dropout(0.2)
        self.flatten = nn.Flatten()
        # Final feature map: 512 channels x 4 x 2 x 2
        self.fc1 = nn.Sequential(
            nn.Linear(512 * 4 * 2 * 2, 512),
            nn.ReLU(),
        )
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        pipeline = (
            self.conv_block1,
            self.conv_block2,
            self.drop_out1,
            self.conv_block3,
            self.drop_out2,
            self.conv_block4,
            self.drop_out3,
            self.flatten,
            self.fc1,
            self.fc2,
        )
        for layer in pipeline:
            x = layer(x)
        return x
================================================
FILE: deploy/backend/models/pytorch_nodule_detector.py
================================================
import os
import numpy as np
import pandas as pd
import torch
from torch.nn import functional as F
from datetime import datetime
import logging
import time
from scipy import ndimage
from data.dataclass.CTData import CTData
from data.dataclass.NoduleCube import normal_cube_to_tensor
from deploy.backend.preprocessing.luna16_invalid_nodule_filter import nodule_valid
from models.pytorch_c3d_tiny import C3dTiny
# Inference parameters
CUBE_SIZE = 32  # edge length of the scanned cube, i.e. 32x32x32 voxels
SCAN_STEP = 10  # scan stride in voxels (scan_ct_data applies it to the y and x axes)
PROB_THRESHOLD = 0.8  # a position counts as a nodule only if its probability exceeds this
# Logging setup
def setup_logger(log_dir="./inference_logs"):
    """Configure and return the ``nodule_detection`` logger.

    Each call creates a new timestamped log file in ``log_dir`` and attaches
    one file handler and one console handler, both at INFO level. Handlers
    left over from an earlier call are removed and closed first:
    ``logging.getLogger`` returns the same logger object every time, so
    without this every call (e.g. one per detection request) would add
    another handler pair, duplicating every log line and leaking open files.

    Args:
        log_dir: directory for log files; created if missing.

    Returns:
        The configured ``logging.Logger``.
    """
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
    logger = logging.getLogger('nodule_detection')
    logger.setLevel(logging.INFO)
    # Drop handlers installed by previous calls so messages are not duplicated.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Log messages contain Chinese text; be explicit so a non-UTF-8 platform
    # default encoding (e.g. on Windows) does not raise on emit.
    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    console_handler = logging.StreamHandler()
    for handler in (file_handler, console_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
def load_ct_data(file_path):
    """
    Load a CT scan (MHD file or DICOM directory) and preprocess it.

    The volume is resampled to 1 mm spacing on every axis, and then
    ``filter_lung_img_mask`` is called on it for lung segmentation.

    Args:
        file_path: path to a ``.mhd`` file (suffix matched case-insensitively)
            or to a directory holding a DICOM series.

    Returns:
        The preprocessed CTData object.

    Raises:
        ValueError: if the path does not exist or is a file of an
            unsupported type.
    """
    if os.path.isfile(file_path):
        # Case-insensitive so files such as "SCAN.MHD" are accepted too.
        if not file_path.lower().endswith('.mhd'):
            raise ValueError(f"不支持的文件类型: {file_path}")
        ct_data = CTData.from_mhd(file_path)
    elif os.path.isdir(file_path):
        # A directory is treated as a DICOM series.
        ct_data = CTData.from_dicom(file_path)
    else:
        raise ValueError(f"指定路径不存在: {file_path}")
    ct_data = ct_data.resample_pixel(new_spacing=[1, 1, 1])
    # Lung segmentation; called for its effect on ct_data (return value unused).
    ct_data.filter_lung_img_mask()
    return ct_data
def load_model(evaL_model_path, device='cuda'):
    """
    Build a C3dTiny network on ``device``, load its weights and switch it to eval mode.

    Args:
        evaL_model_path: path of the saved state_dict
        device: target device ('cuda' or 'cpu', or a torch.device)

    Returns:
        tuple ``(model, device)``: the ready-to-use model and the device
        it was placed on
    """
    network = C3dTiny().to(device)
    weights = torch.load(evaL_model_path, map_location=device)
    network.load_state_dict(weights)
    network.eval()
    return network, device
def get_lung_bounds(lung_mask):
    """Compute per-slice in-plane scan windows around the lungs.

    Connected components of ``lung_mask > 0`` are labelled. When there are
    more than two, only the two largest are kept (normally the left and
    right lung). Each slice along axis 0 qualifies if it has more than 100
    kept voxels. For each qualifying slice, the bounding box of the kept
    voxels is grown by a 5-voxel margin and clipped to the volume.

    Args:
        lung_mask: 3D array indexed (z, y, x); non-zero voxels are lung.

    Returns:
        ``None`` if the mask is empty or no slice qualifies. Otherwise a dict with:

        - ``z_min`` / ``z_max``: range of the qualifying slices, ``z_max``
          exclusive.
        - ``regions``: one ``{'z', 'y_min', 'y_max', 'x_min', 'x_max'}``
          dict per qualifying slice, in increasing z.
    """
    if lung_mask.sum() == 0:
        return None
    labeled_mask, num_features = ndimage.label(lung_mask > 0)
    if num_features > 2:
        # Voxel count of every label in a single pass (index 0 is background)
        # instead of one full-volume comparison per label.
        region_sizes = np.bincount(labeled_mask.ravel(), minlength=num_features + 1)[1:]
        valid_labels = np.argsort(region_sizes)[-2:] + 1
        refined_mask = np.isin(labeled_mask, valid_labels)
    else:
        refined_mask = lung_mask > 0

    margin = 5  # voxels added around the per-slice bounding box
    min_voxels_per_slice = 100  # slices with this many lung voxels or fewer are ignored
    regions = []
    for z in range(refined_mask.shape[0]):
        slice_mask = refined_mask[z]
        if slice_mask.sum() <= min_voxels_per_slice:
            continue
        y_indices, x_indices = np.nonzero(slice_mask)
        regions.append({
            'z': z,
            'y_min': max(0, y_indices.min() - margin),
            'y_max': min(refined_mask.shape[1], y_indices.max() + margin),
            'x_min': max(0, x_indices.min() - margin),
            'x_max': min(refined_mask.shape[2], x_indices.max() + margin),
        })
    if not regions:
        return None
    return {
        'z_min': regions[0]['z'],
        'z_max': regions[-1]['z'] + 1,
        'regions': regions,
    }
def scan_ct_data(ct_data, model, device, logger, step=SCAN_STEP, batch_size=32):
    """
    Slide a CUBE_SIZE^3 window over the lung region and collect nodule candidates.

    For every qualifying slice returned by ``get_lung_bounds``:

    - The window starts at that slice's z, so consecutive slices are scanned
      with a z stride of 1.
    - In-plane, the window moves by ``step`` voxels along y and x.
    - Windows whose lung-mask mean is below 0.1 are skipped.
    - The remaining windows are classified in batches of ``batch_size`` by
      ``process_batch``, which appends positions whose nodule probability
      exceeds the threshold.

    Args:
        ct_data: CTData object providing ``lung_seg_img`` and ``lung_seg_mask``
        model: PyTorch classifier
        device: device the batches are moved to
        logger: logger for progress messages
        step: in-plane scan stride in voxels
        batch_size: number of cubes per forward pass

    Returns:
        DataFrame with columns voxel_coord_x/y/z, world_coord_x/y/z and prob.
        When nothing is found it is empty but still has these columns.
    """
    result_columns = ['voxel_coord_x', 'voxel_coord_y', 'voxel_coord_z',
                      'world_coord_x', 'world_coord_y', 'world_coord_z', 'prob']
    logger.info("开始扫描CT数据...")
    lung_img = ct_data.lung_seg_img
    lung_mask = ct_data.lung_seg_mask
    bounds = get_lung_bounds(lung_mask)
    if bounds is None:
        logger.warning("未能找到有效的肺部区域")
        return pd.DataFrame(columns=result_columns)
    logger.info(f"已确定肺部区域: Z轴范围 {bounds['z_min']} 到 {bounds['z_max']}, 共 {len(bounds['regions'])} 个切片")

    def axis_starts(lo, hi):
        # Cube start positions along one in-plane axis; each cube must end by ``hi``.
        return range(lo, hi - CUBE_SIZE + 1, step)

    # Slices too close to the end of the volume cannot hold a full cube in z.
    scan_regions = [region for region in bounds['regions']
                    if region['z'] + CUBE_SIZE <= lung_img.shape[0]]
    # Exact number of window positions visited below (kept + skipped). The old
    # estimate ignored CUBE_SIZE and the z check, so progress never hit 100%.
    total_voxels = sum(
        len(axis_starts(r['y_min'], r['y_max'])) * len(axis_starts(r['x_min'], r['x_max']))
        for r in scan_regions
    )
    logger.info(f"预计扫描体素数: {total_voxels}")

    start_time = time.time()
    results = []
    batch_inputs = []
    batch_positions = []
    processed_voxels = 0
    skipped_voxels = 0
    lung_tissue_threshold = 0.1  # minimum mean of the lung mask inside a cube
    report_every = 1000

    for region in scan_regions:
        z = region['z']
        for y in axis_starts(region['y_min'], region['y_max']):
            for x in axis_starts(region['x_min'], region['x_max']):
                mask_cube = lung_mask[z:z + CUBE_SIZE, y:y + CUBE_SIZE, x:x + CUBE_SIZE]
                if np.mean(mask_cube) < lung_tissue_threshold:
                    skipped_voxels += 1
                else:
                    cube = lung_img[z:z + CUBE_SIZE, y:y + CUBE_SIZE, x:x + CUBE_SIZE]
                    batch_inputs.append(normal_cube_to_tensor(cube).unsqueeze(0))
                    batch_positions.append((z, y, x))
                    if len(batch_inputs) == batch_size:
                        process_batch(batch_inputs, batch_positions, model, device, ct_data, results)
                        batch_inputs = []
                        batch_positions = []
                    processed_voxels += 1
                visited = processed_voxels + skipped_voxels
                # Checked after skipped windows too, so no report point is missed.
                if visited % report_every == 0:
                    elapsed_time = time.time() - start_time
                    progress = visited / total_voxels * 100 if total_voxels > 0 else 0
                    logger.info(f"处理进度: {visited}/{total_voxels} ({progress:.2f}%), "
                                f"已跳过: {skipped_voxels}, 耗时: {elapsed_time:.2f}秒")

    # Flush the final, possibly partial, batch.
    if batch_inputs:
        process_batch(batch_inputs, batch_positions, model, device, ct_data, results)

    if results:
        results_df = pd.DataFrame(results)
        logger.info(f"扫描完成! 发现 {len(results_df)} 个可能的结节")
    else:
        results_df = pd.DataFrame(columns=result_columns)
        logger.info("扫描完成! 未发现任何结节")
    return results_df
def process_batch(batch_inputs, batch_positions, model, device, ct_data, results, threshold=None):
    """
    Classify one batch of cubes and append confident detections to ``results``.

    Args:
        batch_inputs: list of tensors, each with a leading batch dimension,
            concatenated along dim 0
        batch_positions: list of ``(z, y, x)`` cube start positions aligned
            with ``batch_inputs``
        model: classifier returning 2-class logits
        device: device the batch is run on
        ct_data: object providing ``voxel_to_world([x, y, z])``
        results: list that detection dicts are appended to (mutated in place)
        threshold: probability cut-off (strictly greater passes); defaults
            to ``PROB_THRESHOLD``
    """
    if threshold is None:
        threshold = PROB_THRESHOLD
    batch_tensor = torch.cat(batch_inputs, dim=0).to(device)
    with torch.no_grad():
        batch_outputs = model(batch_tensor)
        # Class-1 (nodule) probability, copied to the host once per batch
        # instead of one blocking ``.item()`` device sync per cube.
        batch_probs = F.softmax(batch_outputs, dim=1)[:, 1].cpu().tolist()
    half_cube = CUBE_SIZE // 2
    for (z_pos, y_pos, x_pos), prob_value in zip(batch_positions, batch_probs):
        if prob_value <= threshold:
            continue
        # Cube centre in voxel coordinates.
        center_z = z_pos + half_cube
        center_y = y_pos + half_cube
        center_x = x_pos + half_cube
        world_coord = ct_data.voxel_to_world([center_x, center_y, center_z])
        results.append({
            'voxel_coord_x': center_x,
            'voxel_coord_y': center_y,
            'voxel_coord_z': center_z,
            'world_coord_x': world_coord[0],
            'world_coord_y': world_coord[1],
            'world_coord_z': world_coord[2],
            'prob': prob_value
        })
def reduce_overlapping_nodules(results_df, distance_threshold=15):
    """
    Greedy non-maximum suppression of nearby nodule detections.

    Rows are visited from highest to lowest ``prob``. Each surviving row
    suppresses every lower-ranked surviving row whose voxel-space Euclidean
    distance to it is strictly less than ``distance_threshold``.

    Args:
        results_df: DataFrame with voxel_coord_x/y/z and prob columns
        distance_threshold: suppression radius in voxels

    Returns:
        The surviving rows sorted by descending prob with a fresh 0..n-1
        index. An input with at most one row is returned unchanged.
    """
    if len(results_df) <= 1:
        return results_df
    sorted_df = results_df.sort_values('prob', ascending=False).reset_index(drop=True)
    # Pull the coordinates out once; per-pair ``iloc`` row access was very slow.
    coords = sorted_df[['voxel_coord_x', 'voxel_coord_y', 'voxel_coord_z']].to_numpy(dtype=np.float64)
    keep_mask = np.ones(len(sorted_df), dtype=bool)
    for i in range(len(sorted_df)):
        if not keep_mask[i]:
            continue
        distances = np.sqrt(((coords[i + 1:] - coords[i]) ** 2).sum(axis=1))
        # Rows already suppressed stay suppressed; close rows are newly dropped.
        keep_mask[i + 1:] &= ~(distances < distance_threshold)
    return sorted_df[keep_mask].reset_index(drop=True)
def filter_false_positives(nodules_df, ct_data, max_nodules=10):
    """
    Drop anatomically implausible detections and cap the count per patient.

    Each row is checked with ``nodule_valid`` at its integer voxel position.
    The survivors are then limited to the ``max_nodules`` rows with the
    highest ``prob``. Validity is checked before capping, so invalid
    high-probability candidates cannot push valid ones out of the top N.

    Args:
        nodules_df: DataFrame with voxel_coord_x/y/z and prob columns
        ct_data: CTData object passed through to ``nodule_valid``
        max_nodules: maximum number of nodules to keep per patient

    Returns:
        Filtered DataFrame with the same columns (also when empty) and a
        fresh 0..n-1 index.
    """
    if nodules_df.empty:
        return nodules_df
    # 1. Position-based plausibility check for every candidate.
    is_valid = np.array([
        nodule_valid(ct_data,
                     int(row['voxel_coord_x']),
                     int(row['voxel_coord_y']),
                     int(row['voxel_coord_z']))
        for _, row in nodules_df.iterrows()
    ], dtype=bool)
    # Boolean indexing keeps the columns and dtypes even when nothing survives
    # (rebuilding from a list of rows lost them).
    filtered_df = nodules_df[is_valid]
    # 2. Cap the count among the valid candidates only.
    if len(filtered_df) > max_nodules:
        filtered_df = filtered_df.sort_values('prob', ascending=False).head(max_nodules)
    # Fresh index so downstream numbering has no gaps.
    return filtered_df.reset_index(drop=True)
def format_results(results_df, ct_data, patient_id, default_diameter_mm=None):
    """
    Convert merged detections into the final per-nodule report table.

    Args:
        results_df: DataFrame with voxel_coord_*, world_coord_* and prob columns
        ct_data: CTData object (unused here; kept for interface compatibility)
        patient_id: value written to the ``patient_id`` column
        default_diameter_mm: diameter reported for every nodule. Defaults to
            ``CUBE_SIZE / 2`` because no per-nodule size is estimated.

    Returns:
        DataFrame with columns patient_id, nodule_id, voxel_x/y/z,
        world_x/y/z, diameter_mm and prob. ``nodule_id`` runs 1..n in row
        order regardless of the input index, which may have gaps after
        filtering.
    """
    columns = ['patient_id', 'nodule_id', 'voxel_x', 'voxel_y', 'voxel_z',
               'world_x', 'world_y', 'world_z', 'diameter_mm', 'prob']
    if results_df.empty:
        return pd.DataFrame(columns=columns)
    diameter_mm = CUBE_SIZE / 2 if default_diameter_mm is None else default_diameter_mm
    final_results = [
        {
            'patient_id': patient_id,
            'nodule_id': nodule_id,
            'voxel_x': int(row['voxel_coord_x']),
            'voxel_y': int(row['voxel_coord_y']),
            'voxel_z': int(row['voxel_coord_z']),
            'world_x': row['world_coord_x'],
            'world_y': row['world_coord_y'],
            'world_z': row['world_coord_z'],
            'diameter_mm': diameter_mm,
            'prob': row['prob'],
        }
        for nodule_id, (_, row) in enumerate(results_df.iterrows(), start=1)
    ]
    return pd.DataFrame(final_results, columns=columns)
def detect_nodules(file_path, model_path, detect_patient_id=None, device='cuda'):
    """
    Run the full nodule-detection pipeline on one CT scan.

    Steps: load and preprocess the scan, load the model, run the
    sliding-window scan, merge overlapping detections, drop false positives,
    and format the report.

    Args:
        file_path: ``.mhd`` file or DICOM directory
        model_path: path of the model's state_dict
        detect_patient_id: id written to the report. If None, the file name
            without extension (for a file) or the directory name is used.
        device: computation device ('cuda', 'cpu' or a torch.device)

    Returns:
        DataFrame produced by ``format_results``.

    Raises:
        Any exception from the pipeline, re-raised after logging it with its
        traceback.
    """
    logger = setup_logger()
    if detect_patient_id is None:
        if os.path.isfile(file_path):
            detect_patient_id = os.path.splitext(os.path.basename(file_path))[0]
        else:
            # normpath strips a trailing separator, which would otherwise
            # make basename return an empty id for "dir/".
            detect_patient_id = os.path.basename(os.path.normpath(file_path))
    logger.info(f"开始处理患者 {detect_patient_id} 的CT数据")
    try:
        logger.info(f"加载CT数据: {file_path}")
        ct_data = load_ct_data(file_path)
        logger.info(f"加载模型: {model_path}")
        model, device = load_model(model_path, device)
        results_df = scan_ct_data(ct_data, model, device, logger)
        logger.info("合并重叠结节...")
        reduced_df = reduce_overlapping_nodules(results_df)
        logger.info(f"合并后的结节数量: {len(reduced_df)}")
        logger.info("过滤假阳性结节...")
        filtered_df = filter_false_positives(reduced_df, ct_data)
        logger.info(f"过滤后的结节数量: {len(filtered_df)}")
        # Use the resolved id. The old code read the module-level
        # ``patient_id``, which only exists when this file runs as a script,
        # so importing callers hit a NameError here.
        final_df = format_results(filtered_df, ct_data, detect_patient_id)
        logger.info(f"检测完成,找到 {len(final_df)} 个结节")
        return final_df
    except Exception as e:
        logger.error(f"检测过程中出错: {str(e)}", exc_info=True)
        raise
if __name__ == "__main__":
    # Ad-hoc local run on one LUNA16 scan; the paths are machine-specific.
    patient_id = "1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139"
    test_mhd = "H:/luna16/subset8/1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139.mhd"
    model_path = "../training/pytorch_checkpoints/best_model.pth"
    threshold = 0.7  # NOTE(review): unused; detection uses PROB_THRESHOLD
    detect_result_csv = f"./c3d_classify_result-{patient_id}.csv"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Run detection (patient id derived from the file name) and save the report.
    result_df = detect_nodules(test_mhd, model_path, None, device)
    result_df.to_csv(detect_result_csv, index=False, encoding="utf-8")
================================================
FILE: deploy/backend/preprocessing/__init__.py
================================================
================================================
FILE: deploy/backend/preprocessing/luna16_invalid_nodule_filter.py
================================================
## 去掉 Luna2016 候选结节数据中 有问题的标注数据 以及 用在预测过程中的 错误结节
import numpy as np
def nodule_valid(ct_data, voxel_coord_x, voxel_coord_y,voxel_coord_z):
    """
    Decide whether a candidate cube centre is usable, either as a training
    cube or as a scanned detection.

    :param ct_data: CTData whose ``lung_seg_mask`` is indexed (z, y, x)
    :param voxel_coord_x: x voxel index of the cube centre
    :param voxel_coord_y: y voxel index of the cube centre
    :param voxel_coord_z: z voxel index of the cube centre
    :return: True if usable, False otherwise. A point is rejected when any of these holds:
        - it lies outside the mask;
        - the mean of the mask over its 11x11x11 neighbourhood (clipped at the
          volume border) is below 0.5;
        - for points not on the volume border, fewer than 4 of its 6 face
          neighbours are set in the mask.
    """
    mask = ct_data.lung_seg_mask
    z, y, x = voxel_coord_z, voxel_coord_y, voxel_coord_x
    depth, height, width = mask.shape[0], mask.shape[1], mask.shape[2]

    if not (0 <= z < depth and 0 <= y < height and 0 <= x < width):
        return False

    radius = 5
    neighbourhood = mask[max(0, z - radius):min(depth, z + radius + 1),
                         max(0, y - radius):min(height, y + radius + 1),
                         max(0, x - radius):min(width, x + radius + 1)]
    # Too little lung around the point: likely a false positive.
    if np.mean(neighbourhood) < 0.5:
        return False

    interior = 0 < z < depth - 1 and 0 < y < height - 1 and 0 < x < width - 1
    if not interior:
        return True

    face_neighbours = (
        mask[z - 1, y, x], mask[z + 1, y, x],
        mask[z, y - 1, x], mask[z, y + 1, x],
        mask[z, y, x - 1], mask[z, y, x + 1],
    )
    # Several non-lung face neighbours suggest the point sits on the lung edge.
    if sum(face_neighbours) < 4:
        return False
    return True
================================================
FILE: deploy/backend/util/__init__.py
================================================
================================================
FILE: deploy/backend/util/dicom_util.py
================================================
import os
import glob
import pydicom
import numpy as np
import cv2
from tqdm import tqdm
from util.seg_util import get_segmented_lungs,normalize_hu_values
from util.image_util import rescale_patient_images
def is_dicom_file(filename):
    '''
    Check whether a file is a DICOM file by its magic bytes.

    A DICOM Part 10 file starts with a 128-byte preamble followed by the
    four bytes ``b'DICM'``.

    :param filename: path of the file to check
    :return: True if the marker is present at offset 128, False otherwise
    '''
    # ``with`` closes the handle even if seek/read raises.
    with open(filename, 'rb') as file_stream:
        file_stream.seek(128)
        return file_stream.read(4) == b'DICM'
def get_dicom_thickness(dicom_slices):
    """
    Determine the slice thickness (z spacing) of a DICOM series.

    For two or more slices, three sources are tried in order:

    1. the z difference of ImagePositionPatient between the first two slices;
    2. the difference of their SliceLocation;
    3. the first slice's SliceThickness.

    For a single slice only SliceThickness is tried. A source is accepted
    only if it yields a positive, finite value; zero spacing (e.g. duplicated
    slices) falls through to the next source. If nothing works, a warning is
    printed and 1.0 is returned.

    :param dicom_slices: sequence of pydicom datasets, in series order
    :return: slice thickness in mm
    """
    # Errors from missing/ill-typed attributes; unlike the old bare
    # ``except:`` this does not swallow KeyboardInterrupt or SystemExit.
    recoverable = (AttributeError, IndexError, KeyError, TypeError, ValueError)
    if len(dicom_slices) > 1:
        first, second = dicom_slices[0], dicom_slices[1]
        candidates = (
            lambda: abs(float(first.ImagePositionPatient[2]) - float(second.ImagePositionPatient[2])),
            lambda: abs(float(first.SliceLocation) - float(second.SliceLocation)),
            lambda: float(first.SliceThickness),
        )
        fallback_message = "警告: 无法确定切片厚度,使用默认值1.0mm"
    else:
        candidates = (lambda: float(dicom_slices[0].SliceThickness),)
        fallback_message = "警告: 只有一个切片,无法计算切片厚度,使用默认值1.0mm"
    for compute in candidates:
        try:
            slice_thickness = compute()
        except recoverable:
            continue
        if 0 < slice_thickness < float('inf'):
            return slice_thickness
    print(fallback_message)
    return 1.0
def load_dicom_slices(dicom_path):
    """
    Load every DICOM file under a directory tree as a sorted list of datasets.

    A file is used when its name ends in ``.dcm``/``.dicom`` (case-insensitive)
    and it carries the DICOM magic marker. Files pydicom cannot read are
    reported and skipped. Slices are sorted by InstanceNumber, and every
    slice's SliceThickness is overwritten with the value from
    ``get_dicom_thickness``.

    :param dicom_path: root directory, searched recursively
    :return: list of pydicom datasets
    :raises ValueError: if no DICOM file is found
    """
    dicom_files = []
    for root, _, files in os.walk(dicom_path):
        for file in files:
            if not file.lower().endswith(('.dcm', '.dicom')):
                continue
            # ``root`` already starts with ``dicom_path``; joining dicom_path
            # again doubled the prefix whenever a relative path was given.
            real_file = os.path.join(root, file)
            if is_dicom_file(real_file):
                dicom_files.append(real_file)
    if not dicom_files:
        raise ValueError(f"在路径 {dicom_path} 中未找到DICOM文件")
    slices = []
    for file in dicom_files:
        try:
            slices.append(pydicom.dcmread(file))
        except Exception as e:
            # Best effort: one corrupt file should not abort the whole series.
            print(f"无法读取DICOM文件 {file}: {e}")
    slices.sort(key=lambda x: int(x.InstanceNumber))
    slice_thickness = get_dicom_thickness(slices)
    for s in slices:
        s.SliceThickness = slice_thickness
    return slices
def get_pixels_hu(slices):
    '''
    Stack the slices' pixel data and convert it to Hounsfield units.

    Pixels equal to -2000 (outside-scan padding) are set to 0 first. Each
    slice is then multiplied by its RescaleSlope (when not 1, truncating
    back to int16), and its RescaleIntercept, cast to int16, is added.

    :param slices: list of DICOM datasets
    :return: int16 array of shape (n_slices, rows, cols)
    '''
    volume = np.stack([dcm.pixel_array for dcm in slices]).astype(np.int16)
    volume[volume == -2000] = 0
    for idx, dcm in enumerate(slices):
        intercept = dcm.RescaleIntercept
        slope = dcm.RescaleSlope
        if slope != 1:
            volume[idx] = (slope * volume[idx].astype(np.float64)).astype(np.int16)
        volume[idx] += np.int16(intercept)
    return np.array(volume, dtype=np.int16)
def getinfo_dicom(dicom_path):
    """
    Read a DICOM series, print diagnostics and return its geometry.

    Needs at least two slices (the z order is derived from slices 0 and 1).

    :param dicom_path: directory containing the series
    :return: tuple ``(pixel_spacing, dicom_size, invert_order)``:
        - pixel_spacing: a new list ``[PixelSpacing[0], PixelSpacing[1], SliceThickness]``;
        - dicom_size: the (n_slices, rows, cols) shape of the stacked pixel data;
        - invert_order: True when the second slice's z position is greater
          than the first's.
    """
    print('dicom_path: ', dicom_path)
    slices = load_dicom_slices(dicom_path)
    print(type(slices[0]), slices[0].ImagePositionPatient)
    print(len(slices), "\t", slices[0].SliceThickness, "\t", slices[0].PixelSpacing)
    print("Orientation: ", slices[0].ImageOrientationPatient)
    image = get_pixels_hu(slices)
    print(image.shape)
    invert_order = slices[1].ImagePositionPatient[2] > slices[0].ImagePositionPatient[2]
    print("Invert order: ", invert_order, " - ", slices[1].ImagePositionPatient[2], ",",
          slices[0].ImagePositionPatient[2])
    # Build a new list: appending to ``slices[0].PixelSpacing`` mutated the
    # loaded dataset's own header value.
    pixel_spacing = list(slices[0].PixelSpacing) + [slices[0].SliceThickness]
    dicom_size = [image.shape[0], image.shape[1], image.shape[2]]
    return pixel_spacing, dicom_size, invert_order
def extract_dicom_images_patient(dicom_path, target_dir):
    """
    Convert a DICOM series into per-slice PNG images and lung masks.

    Processing steps:

    - Convert the series to HU and rescale it with ``rescale_patient_images``.
    - Flip along axis 0 when ``invert_order`` is False.
    - For each slice ``i``, write ``img_<i:04d>_i.png`` (normalised image) and
      ``img_<i:04d>_m.png`` (lung mask) into ``target_dir``.

    If ``target_dir`` already exists, nothing is written.

    :param dicom_path: directory containing the series
    :param target_dir: output directory for the PNG files (created with parents)
    :return: tuple ``(pixel_spacing, dicom_size, png_size, invert_order)``
    """
    slices = load_dicom_slices(dicom_path)
    # Only axis-aligned axial series are supported.
    assert slices[0].ImageOrientationPatient == [1.000000, 0.000000, 0.000000, 0.000000, 1.000000, 0.000000]
    image = get_pixels_hu(slices)
    invert_order = slices[1].ImagePositionPatient[2] > slices[0].ImagePositionPatient[2]
    # Copy instead of appending to the dataset's own PixelSpacing attribute.
    pixel_spacing = list(slices[0].PixelSpacing) + [slices[0].SliceThickness]
    dicom_size = [image.shape[0], image.shape[1], image.shape[2]]
    image = rescale_patient_images(image, pixel_spacing)
    png_size = [image.shape[0], image.shape[1], image.shape[2]]
    if not invert_order:
        image = np.flipud(image)
    if os.path.exists(target_dir):
        print("png dir already exists, return directly")
        return pixel_spacing, dicom_size, png_size, invert_order
    # makedirs also creates missing parents. The directory is brand new, so
    # there are no stale PNGs to delete first.
    os.makedirs(target_dir)
    for i in tqdm(range(image.shape[0])):
        # ``patient_dir`` was undefined here (NameError); files go to target_dir.
        img_path = os.path.join(target_dir, "img_" + str(i).rjust(4, '0') + "_i.png")
        org_img = image[i]
        _lung_img, mask = get_segmented_lungs(org_img.copy())
        org_img = normalize_hu_values(org_img)
        cv2.imwrite(img_path, org_img * 255)
        cv2.imwrite(img_path.replace("_i.png", "_m.png"), mask * 255)
    return pixel_spacing, dicom_size, png_size, invert_order
================================================
FILE: deploy/backend/util/image_util.py
================================================
from typing import Tuple
import cv2
import os
import numpy
import glob
import random
import numpy as np
from scipy import ndimage
def get_normalized_img_unit8(img):
    """Min-max scale ``img`` to 0..255 and return it as uint8.

    Values are truncated, not rounded, by the final cast. A constant image
    has no range to stretch and maps to all zeros instead of NaN garbage.
    """
    # ``numpy.float`` was removed in NumPy 1.24; use the explicit 64-bit type.
    img = img.astype(numpy.float64)
    lo = img.min()
    hi = img.max()
    if hi == lo:
        return numpy.zeros(img.shape, dtype=numpy.uint8)
    img -= lo
    img /= hi - lo
    img *= 255
    return img.astype(numpy.uint8)
def load_patient_images(png_path, wildcard="*.*", exclude_wildcards=None):
    """Load matching grayscale images from a directory into one 3D array.

    Args:
        png_path: directory containing the images.
        wildcard: glob pattern, relative to ``png_path``, of the images to load.
        exclude_wildcards: optional glob patterns, relative to ``png_path``,
            whose matches are left out.

    Returns:
        uint8 array of shape (n_images, height, width), ordered by sorted path.

    Raises:
        ValueError: if no file matches or a matched file cannot be read.
    """
    print("png path is\t", png_path)
    src_dir = png_path
    src_img_paths = glob.glob(src_dir + '/' + wildcard)
    excluded = set()
    for exclude_wildcard in exclude_wildcards or ():
        # Build exclusion patterns the same way as the include pattern and
        # compare normalised paths. Previously "dir//x.png" never matched
        # "dir/x.png", so exclusions silently did nothing.
        excluded.update(os.path.normpath(p) for p in glob.glob(src_dir + '/' + exclude_wildcard))
    src_img_paths = sorted(p for p in src_img_paths if os.path.normpath(p) not in excluded)
    if not src_img_paths:
        raise ValueError(f"no images matching {wildcard!r} found in {src_dir}")
    images = []
    for img_path in src_img_paths:
        im = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if im is None:
            raise ValueError(f"failed to read image {img_path}")
        images.append(im.reshape((1, ) + im.shape))
    return numpy.vstack(images)
def draw_overlay(png_path: str, p_x: float, p_y: float, p_z: float, index: str, BOX_size:int = 20) -> None:
    """
    Draw a square marker on one slice of a patient's PNG stack and save it.

    Args:
        png_path: patient directory; slices are read from ``<png_path>/png/*_i.png``
        p_x: x position as a fraction of the image width
        p_y: y position as a fraction of the image height
        p_z: z position as a fraction of the slice count; clamped to a valid
            slice so 1.0 (or values slightly outside 0..1) do not fail
        index: output file stem; the result is written to ``<png_path>/<index>.png``
        BOX_size: half the side length of the square, in pixels
    """
    patient_img = load_patient_images(png_path + "/png/", "*_i.png", [])
    depth = patient_img.shape[0]
    # int(1.0 * depth) == depth raised IndexError, and a negative value
    # silently wrapped to a slice from the end.
    z = min(max(int(p_z * depth), 0), depth - 1)
    y = int(p_y * patient_img.shape[1])
    x = int(p_x * patient_img.shape[2])
    top_left = (x - BOX_size, y - BOX_size)
    bottom_right = (x + BOX_size, y + BOX_size)
    target_img = patient_img[z, :, :]
    cv2.rectangle(target_img, top_left, bottom_right, (255, 0, 0), 1)
    cv2.imwrite(png_path + "/" + index + ".png", target_img)
def prepare_image_for_net3D(img,MEAN_PIXEL_VALUE = 41):
    '''
    Zero-centre and scale a 3D volume for the network, adding batch and channel axes.

    :param img: 3D array (d0, d1, d2); it is copied, never modified
    :param MEAN_PIXEL_VALUE: value subtracted before dividing by 255
    :return: float32 array of shape (1, d0, d1, d2, 1)
    '''
    volume = img.astype(numpy.float32)
    volume -= MEAN_PIXEL_VALUE
    volume /= 255.
    d0, d1, d2 = volume.shape[0], volume.shape[1], volume.shape[2]
    return volume.reshape(1, d0, d1, d2, 1)
def move_png2dir(target_dir):
    """Move the PNG files of each immediate sub-directory into its ``png`` folder.

    Each directory ``d`` directly under ``target_dir`` is handled as follows:

    - ``d/png`` is created as soon as ``d`` has at least one entry.
    - Every entry of ``d`` whose name ends in ``.png`` is moved into it, and
      the move is printed.
    """
    import shutil
    subdirs = [os.path.join(target_dir, name) for name in os.listdir(target_dir)
               if os.path.isdir(os.path.join(target_dir, name))]
    for subdir in subdirs:
        png_dir = os.path.join(subdir, 'png')
        for entry in os.listdir(subdir):
            if not os.path.exists(png_dir):
                os.mkdir(png_dir)
            source = os.path.join(subdir, entry)
            if not source.endswith(".png"):
                continue
            destination = os.path.join(png_dir, entry)
            shutil.move(source, destination)
            print("move file from %s to %s " % (source, destination))
def rescale_patient_images(images_zyx, org_spacing_xyz, target_voxel_mm =1.0, is_mask_image=False, verbose=False):
    '''
    Resample a (z, y, x) volume to ``target_voxel_mm`` voxels using OpenCV.

    The resampling runs in two passes:

    1. Pass 1 scales z by ``org_spacing_xyz[2] / target_voxel_mm``.
    2. Pass 2 scales y and x by ``org_spacing_xyz[1]`` and ``org_spacing_xyz[0]``,
       each divided by ``target_voxel_mm``.

    ``cv2.resize`` treats the last axis as channels and handles at most 512,
    so pass 2 processes the z slices in chunks of at most 512.

    :param images_zyx: source volume, axes (z, y, x)
    :param org_spacing_xyz: original voxel spacing in mm, ordered (x, y, z)
    :param target_voxel_mm: target voxel size in mm
    :param is_mask_image: use nearest-neighbour interpolation (for masks)
    :param verbose: print shapes before and after
    :return: resampled volume, axes (z, y, x)
    '''
    if verbose:
        print("Spacing: ", org_spacing_xyz)
        print("Shape: ", images_zyx.shape)
    interpolation = cv2.INTER_NEAREST if is_mask_image else cv2.INTER_LINEAR
    # Pass 1: cv2 sees (rows=z, cols=y, channels=x); only rows (z) are scaled.
    res = cv2.resize(images_zyx, dsize=None, fx=1.0,
                     fy=float(org_spacing_xyz[2]) / float(target_voxel_mm),
                     interpolation=interpolation)
    # (z, y, x) -> (y, x, z): z becomes the channel axis for pass 2.
    res = res.swapaxes(0, 2)
    res = res.swapaxes(0, 1)
    resize_x = float(org_spacing_xyz[0]) / float(target_voxel_mm)
    resize_y = float(org_spacing_xyz[1]) / float(target_voxel_mm)
    if res.shape[2] > 512:
        # Split z into ceil(n/512) near-equal chunks. The old single split at
        # 256 still exceeded 512 channels for volumes over 768 slices. Each
        # chunk keeps at least 256 slices, so cv2 never collapses a
        # 1-channel chunk to 2D. Channels are resized independently, so
        # chunking does not change the result.
        res = res.swapaxes(0, 2)  # (z, x, y)
        n_chunks = -(-res.shape[0] // 512)
        parts = []
        for chunk in numpy.array_split(res, n_chunks, axis=0):
            chunk = cv2.resize(chunk.swapaxes(0, 2), dsize=None, fx=resize_x, fy=resize_y,
                               interpolation=interpolation)
            parts.append(chunk.swapaxes(0, 2))
        res = numpy.vstack(parts)
        res = res.swapaxes(0, 2)
    else:
        res = cv2.resize(res, dsize=None, fx=resize_x, fy=resize_y, interpolation=interpolation)
    # (y, x, z) -> (z, y, x)
    res = res.swapaxes(0, 2)
    res = res.swapaxes(2, 1)
    if verbose:
        print("Shape after: ", res.shape)
    return res
def rescale_patient_images2(images_zyx, target_shape, verbose=False):
    '''
    Resample a (z, y, x) volume to an explicit target shape using OpenCV.

    Pass 1 resizes to ``target_shape[0]`` slices and ``target_shape[1]`` rows.
    Pass 2 resizes each slice to ``(target_shape[1], target_shape[2])``,
    processing at most 512 slices per ``cv2.resize`` call because OpenCV
    treats them as channels.

    :param images_zyx: source volume, axes (z, y, x)
    :param target_shape: desired (z, y, x) shape
    :param verbose: print shapes before and after
    :return: resampled volume, axes (z, y, x), linear interpolation
    '''
    if verbose:
        print("Target: ", target_shape)
        print("Shape: ", images_zyx.shape)
    interpolation = cv2.INTER_LINEAR
    # Pass 1: cv2 sees (rows=z, cols=y, channels=x); dsize is (width, height).
    res = cv2.resize(images_zyx, dsize=(target_shape[1], target_shape[0]), interpolation=interpolation)
    # (z, y, x) -> (y, x, z): z becomes the channel axis for pass 2.
    res = res.swapaxes(0, 2)
    res = res.swapaxes(0, 1)
    slice_dsize = (target_shape[2], target_shape[1])
    if res.shape[2] > 512:
        # Near-equal chunks of at most 512 slices (each at least 256); the
        # old single split at 256 overflowed for more than 768 slices.
        res = res.swapaxes(0, 2)  # (z, x, y)
        n_chunks = -(-res.shape[0] // 512)
        parts = []
        for chunk in numpy.array_split(res, n_chunks, axis=0):
            chunk = cv2.resize(chunk.swapaxes(0, 2), dsize=slice_dsize, interpolation=interpolation)
            parts.append(chunk.swapaxes(0, 2))
        res = numpy.vstack(parts)
        res = res.swapaxes(0, 2)
    else:
        res = cv2.resize(res, dsize=slice_dsize, interpolation=interpolation)
    # (y, x, z) -> (z, y, x)
    res = res.swapaxes(0, 2)
    res = res.swapaxes(2, 1)
    if verbose:
        print("Shape after: ", res.shape)
    return res
def resize_image(image: np.ndarray, new_shape: Tuple[int, ...]) -> np.ndarray:
    """
    Resize an array to ``new_shape``.

    Args:
        image: input array
        new_shape: target shape. The method depends on the combination:
            - 2D target with a 2D image: resized with OpenCV;
            - 2D target with a 3D image: every slice along axis 0 is resized
              with OpenCV, and the result is float64;
            - any other combination: ``scipy.ndimage.zoom`` with ``mode='nearest'``.

    Returns:
        np.ndarray: the resized array
    """
    if len(new_shape) != 2 or image.ndim not in (2, 3):
        zoom_factors = tuple(new / old for new, old in zip(new_shape, image.shape))
        return ndimage.zoom(image, zoom_factors, mode='nearest')
    dsize = (new_shape[1], new_shape[0])  # OpenCV expects (width, height)
    if image.ndim == 2:
        return cv2.resize(image, dsize)
    stacked = np.zeros((image.shape[0], new_shape[0], new_shape[1]))
    for idx, plane in enumerate(image):
        stacked[idx] = cv2.resize(plane, dsize)
    return stacked
def cv_flip(img,cols,rows,degree):
    '''
    Rotate an image about its centre. Despite the name, no mirroring happens.

    :param img: image array
    :param cols: image width in pixels; also the output width
    :param rows: image height in pixels; also the output height
    :param degree: rotation angle in degrees, as taken by cv2.getRotationMatrix2D
    :return: the rotated image, same size, scale 1.0
    '''
    centre = (cols / 2, rows / 2)
    rotation = cv2.getRotationMatrix2D(centre, degree, 1.0)
    return cv2.warpAffine(img, rotation, (cols, rows))
def random_rotate_img(img, chance, min_angle, max_angle):
    '''
    Randomly rotate an image, or a list of images, by one shared angle.

    :param img: a single image or a list of images
    :param chance: probability of applying the rotation
    :param min_angle: minimum angle in degrees (inclusive integer)
    :param max_angle: maximum angle in degrees (inclusive integer)
    :return: the input object itself when no rotation is applied. Otherwise
        the rotated image, or a list of rotated images; a one-element result
        is unwrapped, as in random_translate_img.
    '''
    if random.random() > chance:
        return img
    import cv2
    if not isinstance(img, list):
        img = [img]
    angle = random.randint(min_angle, max_angle)
    rows, cols = img[0].shape[:2]
    # OpenCV points and sizes are (x, y) = (width, height). The old code
    # passed (rows, cols), which is only correct for square images.
    rot_matrix = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, scale=1.0)
    res = []
    for img_inst in img:
        img_inst = cv2.warpAffine(img_inst, rot_matrix, dsize=(img_inst.shape[1], img_inst.shape[0]),
                                  borderMode=cv2.BORDER_CONSTANT)
        res.append(img_inst)
    # The old ``len(res) == 0`` check could never be true, so a single image
    # came back wrapped in a list.
    if len(res) == 1:
        res = res[0]
    return res
def random_flip_img(img, horizontal_chance=0, vertical_chance=0):
    '''
    Randomly mirror an image, or each image of a list, horizontally and/or vertically.

    :param img: a single image or a list of images
    :param horizontal_chance: probability of a horizontal flip
    :param vertical_chance: probability of a vertical flip
    :return: the input unchanged when no flip is drawn, otherwise the flipped
        image (or list of flipped images)
    '''
    import cv2
    flip_horizontal = random.random() < horizontal_chance
    flip_vertical = random.random() < vertical_chance
    if not (flip_horizontal or flip_vertical):
        return img
    # cv2.flip codes: 0 = around the X axis, 1 = around the Y axis, -1 = both
    if flip_vertical:
        flip_val = -1 if flip_horizontal else 0
    else:
        flip_val = 1
    if isinstance(img, list):
        return [cv2.flip(item, flip_val) for item in img]
    return cv2.flip(img, flip_val)
def random_scale_img(img, xy_range, lock_xy=False):
    '''
    Randomly zoom an image, or a list of images, then pad/crop back to the original size.

    :param img: a single image or a list of images of identical size
    :param xy_range: XYRange-like object providing:
        - ``chance``: probability of applying the zoom;
        - ``x_min``/``x_max`` and ``y_min``/``y_max``: scale-factor bounds.
        The chosen factors are stored back in ``last_x``/``last_y``.
    :param lock_xy: use the x factor for both axes
    :param return: the input unchanged when skipped, otherwise a list of images
    '''
    if random.random() > xy_range.chance:
        return img
    if not isinstance(img, list):
        img = [img]
    import cv2
    scale_x = random.uniform(xy_range.x_min, xy_range.x_max)
    scale_y = random.uniform(xy_range.y_min, xy_range.y_max)
    if lock_xy:
        scale_y = scale_x
    org_height, org_width = img[0].shape[:2]
    xy_range.last_x = scale_x
    xy_range.last_y = scale_y
    res = []
    for img_inst in img:
        scaled_width = int(org_width * scale_x)
        scaled_height = int(org_height * scale_y)
        scaled_img = cv2.resize(img_inst, (scaled_width, scaled_height), interpolation=cv2.INTER_CUBIC)
        # Integer division throughout. ``/`` produced floats under Python 3,
        # which copyMakeBorder and slicing reject with a TypeError.
        if scaled_width < org_width:
            extend_left = (org_width - scaled_width) // 2
            extend_right = org_width - extend_left - scaled_width
            scaled_img = cv2.copyMakeBorder(scaled_img, 0, 0, extend_left, extend_right, borderType=cv2.BORDER_CONSTANT)
            scaled_width = org_width
        if scaled_height < org_height:
            extend_top = (org_height - scaled_height) // 2
            extend_bottom = org_height - extend_top - scaled_height
            scaled_img = cv2.copyMakeBorder(scaled_img, extend_top, extend_bottom, 0, 0, borderType=cv2.BORDER_CONSTANT)
            scaled_height = org_height
        # Centre crop back to the original size.
        start_x = (scaled_width - org_width) // 2
        start_y = (scaled_height - org_height) // 2
        tmp = scaled_img[start_y: start_y + org_height, start_x: start_x + org_width]
        res.append(tmp)
    return res
class XYRange:
    """Bounds and application probability for a random augmentation.

    Attributes:
        x_min, x_max: range of the random value drawn for the x axis.
        y_min, y_max: range of the random value drawn for the y axis.
        chance: probability of applying the augmentation.
        last_x, last_y: the values drawn most recently; they start at 0.
    """

    def __init__(self, x_min, x_max, y_min, y_max, chance=1.0):
        self.chance = chance
        self.x_min, self.x_max = x_min, x_max
        self.y_min, self.y_max = y_min, y_max
        self.last_x = 0
        self.last_y = 0

    def get_last_xy_txt(self):
        """Encode the last values as text, e.g. ``"x_m25-y_50"`` for -0.25 / 0.5.

        Each value is multiplied by 100, truncated to an int, and a minus
        sign is written as ``m``.
        """
        def encode(value):
            return str(int(value * 100)).replace("-", "m")
        return "x_" + encode(self.last_x) + "-" + "y_" + encode(self.last_y)
def random_translate_img(img, xy_range, border_mode="constant"):
    '''
    Randomly shift an image, or a list of images, by whole pixels.

    :param img: a single image or a list of images of identical size
    :param xy_range: XYRange-like object with ``chance`` and integer
        ``x_min``/``x_max``/``y_min``/``y_max`` shift bounds; the chosen
        shift is stored in ``last_x``/``last_y``
    :param border_mode: "reflect" for reflected borders, anything else for constant
    :return: the input when skipped. Otherwise the shifted image, or a list
        when more than one image results.
    '''
    if random.random() > xy_range.chance:
        return img
    import cv2
    images = img if isinstance(img, list) else [img]
    org_height, org_width = images[0].shape[:2]
    translate_x = random.randint(xy_range.x_min, xy_range.x_max)
    translate_y = random.randint(xy_range.y_min, xy_range.y_max)
    trans_matrix = numpy.float32([[1, 0, translate_x], [0, 1, translate_y]])
    border_const = cv2.BORDER_REFLECT if border_mode == "reflect" else cv2.BORDER_CONSTANT
    shifted = [cv2.warpAffine(im, trans_matrix, (org_width, org_height), borderMode=border_const)
               for im in images]
    xy_range.last_x = translate_x
    xy_range.last_y = translate_y
    return shifted[0] if len(shifted) == 1 else shifted
def data_augmentation(image: np.ndarray, augment_type: str = 'random') -> np.ndarray:
    """
    Apply a simple random augmentation to an array of any dimensionality.

    Args:
        image: input array
        augment_type: one of
            - 'flip': reverse one random axis;
            - 'rotate': rotate by a random integer angle in [0, 360); for more
              than 2 dimensions, in the plane of a random pair of axes;
            - 'shift': shift by a random integer in [-5, 5] per axis;
            - 'random': pick one of the three above or 'none'.
            Any other value returns ``image`` unchanged.

    Returns:
        np.ndarray: the augmented array
    """
    if augment_type == 'random':
        picked = np.random.choice(['flip', 'rotate', 'shift', 'none'])
        if picked == 'none':
            return image
        return data_augmentation(image, str(picked))
    if augment_type == 'flip':
        return np.flip(image, axis=np.random.randint(0, image.ndim))
    if augment_type == 'rotate':
        if image.ndim == 2:
            return ndimage.rotate(image, np.random.randint(0, 360), reshape=False, mode='nearest')
        plane = tuple(np.random.choice(range(image.ndim), size=2, replace=False))
        degrees = np.random.randint(0, 360)
        return ndimage.rotate(image, degrees, axes=plane, reshape=False, mode='nearest')
    if augment_type == 'shift':
        offsets = np.random.randint(-5, 6, size=image.ndim)
        return ndimage.shift(image, offsets, mode='nearest')
    return image
================================================
FILE: deploy/backend/util/mhd_util.py
================================================
import os
import ntpath
import SimpleITK
import numpy as np
import pandas as pd
import cv2
from data.dataclass.CTData import CTData
from util.seg_util import normalize_hu_values,get_segmented_lungs
from util import image_util
from constant import tianchi
# Target voxel size in mm; presumably passed as target_voxel_mm when rescaling (confirm at call sites)
TARGET_VOXEL_MM = 1.0
# Comma-separated column names of the per-scan MHD metadata table
MHD_INFO_HEAD = "patient_id,shape_0,shape_1,shape_2,origin_x,origin_y,origin_z,direction_z(1_-1)," \
                "spacing_x,spacing_y,spacing_z,rescale_x,rescale_y,rescale_z"
def get_all_mhd_file(BASE_DATA_DIR,base_head,max):
    """
    Collect the .mhd files of numbered subset folders.

    Scans ``<BASE_DATA_DIR>/<base_head>_subsetNN`` for NN = 00 .. max-1
    (zero-padded to two digits). For example, base_head='train' with max=10
    covers train_subset00 .. train_subset09.

    :param BASE_DATA_DIR: directory containing the subset folders
    :param base_head: subset prefix such as 'train', 'test' or 'val'
    :param max: number of subsets to scan (exclusive upper bound of NN)
    :return: list of .mhd file paths
    """
    mhd_files = []
    for subset_idx in range(0, max):
        subset_dir = os.path.join(BASE_DATA_DIR, base_head + "_subset" + format(subset_idx, "02d"))
        mhd_files.extend(os.path.join(subset_dir, name)
                         for name in os.listdir(subset_dir)
                         if name.endswith(".mhd"))
    return mhd_files
def get_luna16_mhd_file(mhd_root):
    """
    Recursively collect every .mhd file under a data root.

    :param mhd_root: root directory to search
    :return: list of paths of files whose name ends in .mhd (case-insensitive),
        each starting with ``mhd_root``
    """
    # ``root`` from os.walk already starts with ``mhd_root``; joining
    # mhd_root again doubled the prefix for relative paths.
    return [
        os.path.join(root, name)
        for root, _, files in os.walk(mhd_root)
        for name in files
        if name.lower().endswith('.mhd')
    ]
def read_csv_to_pandas(mhd_info, col_sepator='\t', head_sepator=','):
    """
    Read an mhd-metadata CSV file into a pandas DataFrame of strings.

    The header line is split on ``head_sepator`` and every following line on
    ``col_sepator``. The defaults keep the original behaviour: the header is
    comma-separated and the data rows are tab-separated. Line endings are
    stripped so they do not leak into the last column name or value, and
    blank lines are skipped.

    :param mhd_info: path of the CSV file
    :param col_sepator: separator between the columns of the data rows
    :param head_sepator: separator between the column names in the header line
    :return: DataFrame whose columns come from the header and whose index is
        the first field of each row (the patient id); all cells are strings
    """
    with open(mhd_info, 'r') as f:
        head = f.readline().rstrip('\r\n').split(head_sepator)
        rows = []
        for line in f:
            line = line.rstrip('\r\n')
            if not line.strip():
                continue
            rows.append(line.split(col_sepator))
    # The first element of each row is the patient id.
    index = [row[0] for row in rows]
    return pd.DataFrame(data=rows, columns=head, index=index)
def extract_image_from_mhd(mhd_file_path, png_save_path_root=None):
    """
    Read an .mhd volume, collect its metadata and optionally export PNG slices.

    :param mhd_file_path: .mhd file to read
    :param png_save_path_root: directory under which a ``<patient_id>/`` folder
        receives, for every slice i, ``img_<iiii>_i.png`` (HU-normalised image)
        and ``img_<iiii>_m.png`` (lung mask). If None, only the metadata is
        returned and nothing is written. If the patient folder already
        exists, the export is skipped so already-extracted patients are not
        redone.
    :return: list of 14 strings, in the column order of MHD_INFO_HEAD:
        patient id; array shape as (shape2, shape1, shape0) of the (z, y, x)
        array returned by SimpleITK; origin (x, y, z); direction flag (1 for
        the identity direction matrix, else -1); spacing (x, y, z); rescale
        factors (spacing / TARGET_VOXEL_MM)
    """
    mhd_info = []
    patient_id = ntpath.basename(mhd_file_path).replace(".mhd", "")
    print("Patient: ", patient_id)
    mhd_info.append(patient_id)

    itk_img = SimpleITK.ReadImage(mhd_file_path)
    img_array = SimpleITK.GetArrayFromImage(itk_img)
    print("Img array: ", img_array.shape)
    (shape0, shape1, shape2) = img_array.shape
    mhd_info.append(str(shape2))
    mhd_info.append(str(shape1))
    mhd_info.append(str(shape0))

    origin = np.array(itk_img.GetOrigin())  # x,y,z origin in world coordinates (mm)
    print("Origin (x,y,z): ", origin)
    mhd_info.append(str(origin[0]))
    mhd_info.append(str(origin[1]))
    mhd_info.append(str(origin[2]))

    direction = np.array(itk_img.GetDirection())  # flattened direction-cosine matrix
    print("Direction: ", direction)
    identity_direction = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
    if direction.tolist() == identity_direction:
        print("positive direction..")
        mhd_info.append(str(1))
    else:
        mhd_info.append(str(-1))

    spacing = np.array(itk_img.GetSpacing())  # spacing of voxels in world coor. (mm)
    print("Spacing (x,y,z): ", spacing)
    mhd_info.append(str(spacing[0]))
    mhd_info.append(str(spacing[1]))
    mhd_info.append(str(spacing[2]))

    rescale = spacing / TARGET_VOXEL_MM
    print("Rescale: ", rescale)
    mhd_info.append(str(rescale[0]))
    mhd_info.append(str(rescale[1]))
    mhd_info.append(str(rescale[2]))

    # Metadata-only mode. This must come before any path handling: os.path
    # functions raise TypeError when given None.
    if png_save_path_root is None:
        return mhd_info

    dst_dir = png_save_path_root + '/' + patient_id + "/"
    # Skip patients that were already exported. The existence check has to
    # happen before the directory is created, otherwise it is always true
    # and nothing is ever written.
    if os.path.exists(dst_dir):
        return mhd_info
    os.makedirs(dst_dir)

    if img_array.shape[1] == 512:
        img_array = image_util.rescale_patient_images(img_array, spacing, TARGET_VOXEL_MM)
    for i, slice_hu in enumerate(img_array):
        # get_segmented_lungs modifies its argument in place, so give it a copy.
        _, mask = get_segmented_lungs(slice_hu.copy())
        img = normalize_hu_values(slice_hu)
        prefix = dst_dir + "img_" + str(i).rjust(4, '0')
        # PNG needs 8-bit data: a bool mask * 255 would be int64, which cv2 rejects.
        cv2.imwrite(prefix + "_i.png", np.rint(img * 255).astype(np.uint8))
        cv2.imwrite(prefix + "_m.png", mask.astype(np.uint8) * 255)
    return mhd_info
================================================
FILE: deploy/backend/util/seg_util.py
================================================
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage as ndi
from skimage.filters import roberts
from skimage.measure import regionprops, label
from skimage.morphology import binary_closing, disk, binary_erosion
from skimage.segmentation import clear_border
def normalize_hu_values(image: np.ndarray, min_bound: int = -1000, max_bound: int = 400) -> np.ndarray:
    """
    Linearly map HU values from [min_bound, max_bound] onto [0, 1].

    Values below ``min_bound`` saturate at 0 and values above ``max_bound``
    saturate at 1. The input array is not modified; a new array is returned.

    Args:
        image: input image
        min_bound: HU value mapped to 0
        max_bound: HU value mapped to 1
    Returns:
        np.ndarray: the normalised image
    """
    window_width = max_bound - min_bound
    scaled = (image - min_bound) / window_width
    return np.clip(scaled, 0., 1.)
def get_segmented_lungs(im, plot=False):
    '''
    Extract the lung ROI from a single 2-D slice.

    Note: ``im`` is modified in place. Every pixel outside the final lung
    mask is overwritten with -2000, and the same array object is returned.
    Pass a copy if the caller still needs the original slice.

    :param im: 2-D pixel array of one slice; the -400 threshold presumes HU
        values (TODO confirm for all callers)
    :param plot: if True, show the mask and the masked slice with matplotlib
    :return: tuple ``(im, binary)``: the masked slice and the boolean lung mask
    '''
    # Step 1: Convert into a binary image (pixels darker than -400).
    binary = im < -400
    # Step 2: Remove the blobs connected to the border of the image.
    cleared = clear_border(binary)
    # Step 3: Label the connected components.
    label_image = label(cleared)
    # Step 4: Keep the labels with the 2 largest areas (ties are kept too).
    regions = regionprops(label_image)
    if len(regions) > 2:
        second_largest_area = sorted(region.area for region in regions)[-2]
        for region in regions:
            if region.area < second_largest_area:
                # Clear the whole region in one vectorised assignment.
                label_image[region.coords[:, 0], region.coords[:, 1]] = 0
    binary = label_image > 0
    # Step 5: Erosion with a disk of radius 2, to separate nodules attached to blood vessels.
    binary = binary_erosion(binary, disk(2))
    # Step 6: Closing with a disk of radius 10, to keep nodules attached to the lung wall.
    binary = binary_closing(binary, disk(10))
    # Step 7: Fill small holes inside the lung mask (fill inside the Roberts edge map).
    edges = roberts(binary)
    binary = ndi.binary_fill_holes(edges)
    # Step 8: Superimpose the mask on the input image (in place).
    im[binary == 0] = -2000
    if plot:
        plt.figure(figsize=(10, 10))
        plt.subplot(1, 2, 1)
        plt.imshow(binary, cmap='gray')
        plt.title('Lung Mask')
        plt.subplot(1, 2, 2)
        plt.imshow(im, cmap='gray')
        plt.title('Masked Image')
        plt.show()
    return im, binary
================================================
FILE: deploy/backend/utils.py
================================================
import os
import sys
import logging
import numpy as np
import SimpleITK as sitk
import tempfile
import zipfile
import shutil
import pydicom
from scipy import ndimage
# Configure logging. basicConfig configures the root logger at import time
# (only if it has no handlers yet); `logger` is this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Try to import the project's CTData class; if that fails, define a simplified fallback.
try:
    # Add the project root directory to sys.path.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
    # Import the original CTData class.
    from data.dataclass.CTData import CTData
except ImportError:
    # The original class is not importable: define a simplified CTData.
    class CTData:
        """Simplified CT data container used to load and process CT images."""

        def __init__(self):
            self.img = None  # voxel array from sitk.GetArrayFromImage, indexed (z, y, x)
            self.origin = None  # image origin, SimpleITK (x, y, z) order
            self.spacing = None  # voxel spacing, SimpleITK (x, y, z) order
            self.is_normalized = False  # True once convert_to_hu() has run
            self.has_lung_seg = False  # True once filter_lung_img_mask() has run
            self.lung_seg_img = None  # img multiplied by lung_mask
            self.lung_mask = None  # boolean lung mask, same shape as img

        @classmethod
        def from_dicom(cls, dicom_path):
            """Load CT data from a DICOM series directory or a ``.zip`` of one.

            :param dicom_path: directory with a DICOM series, or a ``.zip``
                file that is extracted into a temporary directory first
            :return: a new CTData instance (already passed through convert_to_hu)
            :raises ValueError: if no DICOM files are found
            """
            ct_data = cls()
            temp_dir = None
            try:
                logger.info(f"从DICOM加载: {dicom_path}")
                # A .zip upload is extracted into a temporary directory.
                if os.path.isfile(dicom_path) and dicom_path.endswith('.zip'):
                    temp_dir = tempfile.mkdtemp()
                    with zipfile.ZipFile(dicom_path, 'r') as zip_ref:
                        zip_ref.extractall(temp_dir)
                    dicom_path = temp_dir
                # Read the DICOM series.
                reader = sitk.ImageSeriesReader()
                dicom_names = reader.GetGDCMSeriesFileNames(dicom_path)
                if not dicom_names:
                    raise ValueError(f"在{dicom_path}中未找到DICOM文件")
                reader.SetFileNames(dicom_names)
                ct_sitk_img = reader.Execute()
                # Convert to a numpy array and keep the geometry.
                ct_data.img = sitk.GetArrayFromImage(ct_sitk_img)
                ct_data.origin = ct_sitk_img.GetOrigin()
                ct_data.spacing = ct_sitk_img.GetSpacing()
                ct_data.convert_to_hu()
                return ct_data
            except Exception as e:
                logger.error(f"从DICOM加载失败: {e}")
                raise
            finally:
                # Always remove the temporary extraction directory, on success or failure.
                if temp_dir and os.path.exists(temp_dir):
                    shutil.rmtree(temp_dir)

        @classmethod
        def from_mhd(cls, mhd_path):
            """Load CT data from an MHD file.

            Voxel values are used as-is (assumed to already be in HU), so
            convert_to_hu() is not called.
            """
            ct_data = cls()
            try:
                logger.info(f"从MHD加载: {mhd_path}")
                ct_sitk_img = sitk.ReadImage(mhd_path)
                ct_data.img = sitk.GetArrayFromImage(ct_sitk_img)
                ct_data.origin = ct_sitk_img.GetOrigin()
                ct_data.spacing = ct_sitk_img.GetSpacing()
                # Assumed to be HU already; call ct_data.convert_to_hu() here otherwise.
                return ct_data
            except Exception as e:
                logger.error(f"从MHD加载失败: {e}")
                raise

        def resample_pixels(self, new_spacing=(1.0, 1.0, 1.0)):
            """Resample the volume to ``new_spacing`` with linear interpolation.

            ``new_spacing`` uses the same (x, y, z) order as ``self.spacing``.
            ``self.img`` is indexed (z, y, x), so spacings are reversed before
            they are combined with the array shape, and reversed back for
            SimpleITK, which expects (x, y, z) sizes and spacings.

            :return: ``self``, or None when there is no image
            """
            if self.img is None:
                logger.error("没有图像数据可重采样")
                return
            logger.info(f"重采样像素,原始间距: {self.spacing},目标间距: {new_spacing}")
            shape_zyx = np.array(self.img.shape)
            spacing_zyx = np.array(self.spacing, dtype=float)[::-1]
            new_spacing_zyx = np.array(new_spacing, dtype=float)[::-1]
            # Target size; keep at least one voxel per axis.
            new_shape_zyx = np.maximum(np.round(shape_zyx * spacing_zyx / new_spacing_zyx), 1).astype(np.int64)
            # Spacing actually achieved once the size has been rounded to whole voxels.
            real_new_spacing_zyx = spacing_zyx * shape_zyx / new_shape_zyx
            sitk_img = sitk.GetImageFromArray(self.img)
            sitk_img.SetSpacing([float(s) for s in self.spacing])
            resample = sitk.ResampleImageFilter()
            resample.SetInterpolator(sitk.sitkLinear)
            resample.SetOutputSpacing(real_new_spacing_zyx[::-1].tolist())
            resample.SetSize([int(n) for n in new_shape_zyx[::-1]])
            resample.SetOutputDirection(sitk_img.GetDirection())
            resample.SetOutputOrigin(sitk_img.GetOrigin())
            resampled_img = resample.Execute(sitk_img)
            self.img = sitk.GetArrayFromImage(resampled_img)
            self.spacing = real_new_spacing_zyx[::-1]
            return self

        def convert_to_hu(self):
            """Mark the image as being in HU units.

            No conversion is applied: the data is assumed to already be in HU
            (or similar units). Returns None when there is no image or it was
            already marked, otherwise ``self``.
            """
            if self.img is None:
                logger.error("没有图像数据可转换")
                return
            if self.is_normalized:
                logger.info("图像已经转换为HU单位")
                return
            logger.info("转换图像为HU单位")
            # Specific conversion logic can be added here if needed.
            self.is_normalized = True
            return self

        def filter_lung_img_mask(self, threshold=-320):
            """Segment the lungs and store ``lung_mask`` / ``lung_seg_img``.

            Voxels below ``threshold`` are treated as air. Air connected to the
            volume border (outside the body) is removed, a morphological
            opening drops small specks, and only the two largest remaining
            components (expected to be the two lungs) are kept.

            :param threshold: intensity threshold separating air from tissue
            :return: ``self``, or None when there is no image
            """
            if self.img is None:
                logger.error("没有图像数据可分割")
                return
            logger.info("分割肺部区域")
            # Make sure the image is marked as HU.
            if not self.is_normalized:
                self.convert_to_hu()
            # 1 where the voxel is darker than the threshold (air), 0 elsewhere.
            # This is built in one step: assigning 1 first and then testing
            # ``>= threshold`` would also reset those 1s to 0
            # (1 >= -320 for the default threshold).
            threshold_image = (self.img < threshold).astype(np.uint8)
            outside_air = self.fill_body_mask(threshold_image)
            # Air not connected to the border lies inside the body.
            lung_mask = np.logical_xor(threshold_image, outside_air)
            # Remove small connected specks.
            struct = np.ones((2, 2, 2), dtype=np.bool_)
            lung_mask = ndimage.binary_opening(lung_mask, structure=struct, iterations=2)
            labeled_mask, num_features = ndimage.label(lung_mask)
            # Keep only the two largest connected components (the lungs).
            if num_features > 2:
                # Voxel count per label; index 0 is the background and is dropped.
                areas = np.bincount(labeled_mask.ravel())[1:]
                keep_labels = np.argsort(areas)[-2:] + 1
                lung_mask = np.isin(labeled_mask, keep_labels)
            self.lung_mask = lung_mask
            self.lung_seg_img = self.img * lung_mask
            self.has_lung_seg = True
            return self

        def fill_body_mask(self, threshold_image):
            """Return the air that is connected to the border of the volume.

            Every connected component of ``threshold_image > 0`` that touches
            any face of the array is treated as air outside the body.

            NOTE(review): this includes the first and last slices, so lung air
            touching the top or bottom slice is also classed as "outside";
            confirm this is acceptable for the scans being processed.

            :param threshold_image: array whose non-zero voxels mark air
            :return: boolean mask of border-connected air, same shape as the input
            """
            air = np.asarray(threshold_image) > 0
            labeled, _ = ndimage.label(air)
            face_labels = [
                np.take(labeled, index, axis=axis).ravel()
                for axis in range(labeled.ndim)
                for index in (0, -1)
            ]
            border_labels = np.unique(np.concatenate(face_labels))
            border_labels = border_labels[border_labels != 0]
            return np.isin(labeled, border_labels)
def extract_lung_from_image(sitk_image):
    """
    Extract the lung region from a SimpleITK CT image.

    :param sitk_image: SimpleITK CT image
    :return: lung mask as produced by segment_lung (uint8 numpy array with the
        (z, y, x) layout of sitk.GetArrayFromImage)
    """
    # segment_lung only needs the voxel array; the image's spacing, origin
    # and direction are not used here.
    ct_array = sitk.GetArrayFromImage(sitk_image)
    return segment_lung(ct_array)
def _binary_closing_zero_padded(mask, structure):
    """Binary closing that treats everything outside ``mask`` as background.

    ``ndimage.binary_closing`` erodes with ``border_value=0``, which strips
    foreground that touches the array edges (e.g. the whole first and last
    slice of a volume). Padding with background before the closing and
    cropping afterwards removes that edge artefact.
    """
    structure = np.asarray(structure, dtype=bool)
    pad_width = [(size, size) for size in structure.shape]
    padded = np.pad(np.asarray(mask, dtype=bool), pad_width, mode='constant', constant_values=False)
    closed = ndimage.binary_closing(padded, structure=structure)
    crop = tuple(slice(size, size + extent) for size, extent in zip(structure.shape, np.shape(mask)))
    return closed[crop]


def segment_lung(ct_array):
    """
    Simple slice-by-slice lung segmentation.

    :param ct_array: CT image array [z, y, x]; values are compared with HU-like
        thresholds (-1000, -500)
    :return: uint8 lung mask with the same shape as ``ct_array`` (1 = lung)
    """
    # 1. Threshold: lung parenchyma lies roughly between air (-1000 HU) and -500 HU.
    binary_image = np.logical_and(ct_array > -1000, ct_array < -500)
    # 2. Process each slice separately.
    result_mask = np.zeros_like(binary_image, dtype=bool)
    for z in range(binary_image.shape[0]):
        slice_img = binary_image[z].copy()
        # Force the slice border to foreground so all air touching it forms one component.
        slice_img[0, :] = 1
        slice_img[-1, :] = 1
        slice_img[:, 0] = 1
        slice_img[:, -1] = 1
        labeled, _ = ndimage.label(slice_img)
        # Rank components by bounding-box area (not voxel count). The border
        # component spans the whole slice, so it is expected to sort first.
        boxes = ndimage.find_objects(labeled)
        region_sizes = [
            (label_no, (box[0].stop - box[0].start) * (box[1].stop - box[1].start))
            for label_no, box in enumerate(boxes, start=1)
        ]
        region_sizes.sort(key=lambda item: item[1], reverse=True)
        lung_mask = np.zeros_like(slice_img, dtype=bool)
        # Skip the largest component (background/body) and keep the next up to four.
        for label_no, _ in region_sizes[1:5]:
            lung_mask[labeled == label_no] = True
        # Close and fill holes inside the lungs (e.g. vessels).
        lung_mask = _binary_closing_zero_padded(lung_mask, np.ones((5, 5)))
        lung_mask = ndimage.binary_fill_holes(lung_mask)
        result_mask[z] = lung_mask
    # Enforce continuity between slices without eroding the first and last slice.
    result_mask = _binary_closing_zero_padded(result_mask, np.ones((3, 3, 3)))
    return result_mask.astype(np.uint8)
def extract_lung_from_file(file_path):
    """
    Extract the lung region from a CT file.

    Supports single-file images (.mhd, .dcm/.dicom, .nii/.nii.gz) and
    directories containing a DICOM series.

    :param file_path: file or directory path
    :return: dict with 'lung_data' (CT with non-lung voxels set to -2000),
        'lung_mask', 'original_ct', 'spacing', 'origin' and 'direction'
    :raises ValueError: for unsupported paths or directories without DICOM files
    """
    lower_path = file_path.lower()
    if lower_path.endswith(('.mhd', '.dcm', '.dicom', '.nii', '.nii.gz')):
        # Single-file formats are read directly.
        ct_image = sitk.ReadImage(file_path)
    elif os.path.isdir(file_path):
        # Load a DICOM series.
        reader = sitk.ImageSeriesReader()
        dicom_names = reader.GetGDCMSeriesFileNames(file_path)
        if not dicom_names:
            # Same check and message as CTData.from_dicom; otherwise Execute()
            # fails with an opaque ITK error.
            raise ValueError(f"在{file_path}中未找到DICOM文件")
        reader.SetFileNames(dicom_names)
        ct_image = reader.Execute()
    else:
        raise ValueError(f"不支持的文件类型: {file_path}")
    lung_mask = extract_lung_from_image(ct_image)
    # Apply the mask to the original CT.
    ct_array = sitk.GetArrayFromImage(ct_image)
    lung_data = ct_array.copy()
    lung_data[~lung_mask.astype(bool)] = -2000  # non-lung voxels set to -2000 HU
    return {
        'lung_data': lung_data,
        'lung_mask': lung_mask,
        'original_ct': ct_array,
        'spacing': ct_image.GetSpacing(),
        'origin': ct_image.GetOrigin(),
        'direction': ct_image.GetDirection()
    }
def prepare_data_for_3d_rendering(lung_data):
    """
    Turn the lung mask into a lightweight point cloud for 3D rendering.

    :param lung_data: dict with at least a ``'lung_mask'`` array
    :return: dict with ``points`` (index triples of the non-zero voxels of the
        mask downsampled by 4 along each axis), ``dimensions`` (the
        downsampled shape), ``min_bounds`` and ``max_bounds`` (per-axis extent
        of the points, or ``[0, 0, 0]`` / the downsampled shape when empty)
    """
    coarse_mask = lung_data['lung_mask'][::4, ::4, ::4]
    coords = np.argwhere(coarse_mask > 0)

    has_points = coords.shape[0] > 0
    if has_points:
        lower = coords.min(axis=0).tolist()
        upper = coords.max(axis=0).tolist()
    else:
        lower = [0, 0, 0]
        upper = list(coarse_mask.shape)

    return {
        'points': coords.tolist(),
        'dimensions': coarse_mask.shape,
        'min_bounds': lower,
        'max_bounds': upper
    }
================================================
FILE: deploy/frontend/css/style.css
================================================
/* Main stylesheet */
/* Reset and global styles */
* {
    margin: 0;
    padding: 0;
    box-sizing: border-box;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
body {
    background-color: #f5f7fa;
    color: #333;
    line-height: 1.6;
}
/* Top-level flex row (presumably wrapping .sidebar and .main-content) */
.container {
    display: flex;
    min-height: 100vh;
    width: 100%;
}
/* Sidebar: fixed-width dark column whose children are stacked vertically */
.sidebar {
    width: 300px;
    background-color: #2c3e50;
    color: #ecf0f1;
    padding: 20px;
    overflow-y: auto;
    box-shadow: 2px 0 5px rgba(0, 0, 0, 0.1);
    display: flex;
    flex-direction: column;
}
.logo {
    text-align: center;
    padding: 10px 0 20px;
    border-bottom: 1px solid #34495e;
}
.logo h1 {
    font-size: 1.4rem;
    color: #3498db;
}
/* Step navigation */
.steps {
    margin: 20px 0;
}
/* A step is dimmed by default; the .active / .complete modifiers highlight it */
.step {
    display: flex;
    padding: 15px 10px;
    margin-bottom: 10px;
    border-radius: 5px;
    background-color: #34495e;
    opacity: 0.7;
    transition: all 0.3s ease;
}
.step.active {
    background-color: #3498db;
    opacity: 1;
    box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
}
.step.complete {
    background-color: #27ae60;
    opacity: 1;
}
/* Round icon badge to the left of the step text */
.step-icon {
    width: 40px;
    height: 40px;
    border-radius: 50%;
    background-color: rgba(255, 255, 255, 0.2);
    display: flex;
    align-items: center;
    justify-content: center;
    margin-right: 15px;
}
.step-icon i {
    font-size: 1.2rem;
}
.step-content h3 {
    font-size: 1rem;
    margin-bottom: 5px;
}
.step-content p {
    font-size: 0.8rem;
    opacity: 0.8;
}
/* Upload area */
.upload-area {
    background-color: #34495e;
    padding: 15px;
    border-radius: 5px;
    margin-bottom: 20px;
}
.upload-area h3 {
    margin-bottom: 10px;
    font-size: 1rem;
}
.upload-area p {
    font-size: 0.8rem;
    opacity: 0.8;
    margin-bottom: 15px;
}
.file-input-container {
    margin-bottom: 15px;
}
/* Native file inputs are hidden; presumably opened from a styled button via script */
input[type="file"] {
    display: none;
}
.select-file-btn {
    background-color: #3498db;
    color: white;
    border: none;
    padding: 10px 15px;
    border-radius: 5px;
    cursor: pointer;
    width: 100%;
    margin-bottom: 10px;
    font-size: 0.9rem;
    transition: background-color 0.3s;
}
.select-file-btn:hover {
    background-color: #2980b9;
}
/* Selected-file summary; long file names wrap anywhere */
#file-info {
    font-size: 0.8rem;
    padding: 5px;
    background-color: rgba(255, 255, 255, 0.1);
    border-radius: 3px;
    margin-top: 5px;
    word-break: break-all;
}
.patient-info-input {
    margin-bottom: 15px;
}
.patient-info-input label {
    display: block;
    font-size: 0.8rem;
    margin-bottom: 5px;
}
.patient-info-input input {
    width: 100%;
    padding: 8px;
    border-radius: 3px;
    border: 1px solid #546e7a;
    background-color: rgba(255, 255, 255, 0.1);
    color: white;
}
.upload-btn {
    background-color: #e74c3c;
    color: white;
    border: none;
    padding: 10px 15px;
    border-radius: 5px;
    cursor: pointer;
    width: 100%;
    font-size: 0.9rem;
    transition: background-color 0.3s;
}
.upload-btn:hover {
    background-color: #c0392b;
}
.upload-btn:disabled {
    background-color: #7f8c8d;
    cursor: not-allowed;
}
/* Progress area */
.progress-area {
    background-color: #34495e;
    padding: 15px;
    border-radius: 5px;
    margin-bottom: 20px;
}
.progress-area h3 {
    margin-bottom: 10px;
    font-size: 1rem;
}
/* Track of the progress bar; positioned so .progress-text can overlay it */
.progress-container {
    height: 20px;
    background-color: rgba(255, 255, 255, 0.1);
    border-radius: 10px;
    overflow: hidden;
    position: relative;
    margin-bottom: 10px;
}
/* Fill of the progress bar; its width is the progress percentage */
.progress-bar {
    height: 100%;
    background-color: #2ecc71;
    width: 0%;
    transition: width 0.3s ease;
    border-radius: 10px;
}
/* Percentage label centred over the whole track */
.progress-text {
    position: absolute;
    top: 0;
    left: 0;
    right: 0;
    bottom: 0;
    display: flex;
    align-items: center;
    justify-content: center;
    font-size: 0.8rem;
    color: white;
}
.progress-message {
    font-size: 0.8rem;
    margin: 10px 0;
    min-height: 40px;
}
.action-buttons {
    display: flex;
    justify-content: center;
    margin-top: 15px;
}
/* Generic button */
.btn {
    background-color: #3498db;
    color: white;
    border: none;
    padding: 8px 15px;
    border-radius: 5px;
    cursor: pointer;
    font-size: 0.9rem;
    transition: background-color 0.3s;
}
.btn:hover {
    background-color: #2980b9;
}
.btn:disabled {
    background-color: #7f8c8d;
    cursor: not-allowed;
}
/* Patient info area; inside the flex-column .sidebar, margin-top: auto pushes it to the bottom */
.patient-info-area {
    margin-top: auto;
    padding: 10px;
    font-size: 0.8rem;
    background-color: rgba(255, 255, 255, 0.1);
    border-radius: 5px;
}
/* Main content area: a horizontal flex row */
.main-content {
    flex: 1;
    display: flex;
    flex-direction: row;
    overflow: hidden;
    position: relative;
}
/* Viewer container */
.viewer-container {
    flex: 0.7; /* takes 70% of the main area; if .results-panel (flex: 0.25) is its sibling, the grow factors sum to 0.95 and ~5% of the row stays empty */
    height: 100vh;
    padding: 15px;
    overflow: hidden;
    position: relative;
}
/* CT slice viewer */
#ct-viewer {
    background-color: #f5f5f5;
    width: 100%;
    height: 100%;
    min-height: 500px; /* enough height to show the slices */
    position: relative;
    border-radius: 5px;
    overflow: auto; /* scroll when the content overflows */
    box-shadow: 0 2px 8px rgba(0,0,0,0.15);
    display: flex; /* flex layout */
    flex-direction: column; /* stack children vertically */
    justify-content: flex-start; /* start at the top */
    align-items: center; /* centre horizontally */
    padding: 10px; /* inner padding */
}
/* Viewer controls, pinned to the bottom-right corner of the positioned container */
.viewer-controls {
    position: absolute;
    bottom: 25px;
    right: 25px;
    z-index: 10;
}
/* Results panel */
.results-panel {
    flex: 0.25; /* takes 25% of the main area */
    height: 100vh;
    padding: 20px;
    background-color: white;
    box-shadow: -2px 0 5px rgba(0, 0, 0, 0.1);
    display: flex;
    flex-direction: column;
    overflow-y: auto;
}
.results-header {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 20px;
    padding-bottom: 10px;
    border-bottom: 1px solid #e0e0e0;
}
.results-header h2 {
    font-size: 1.2rem;
    color: #2c3e50;
}
.summary {
    font-size: 0.9rem;
    color: #7f8c8d;
}
/* Results layout: nodule list above nodule details */
.results-layout {
    display: flex;
    flex-direction: column;
    gap: 20px;
    height: 100%;
    min-height: 200px;
}
/* Nodule list section */
.nodule-list-section {
    flex: 1;
    max-height: 300px;
}
.nodule-list-section h3 {
    margin-bottom: 10px;
    color: #2c3e50;
    font-size: 1rem;
}
#nodule-list {
    margin-top: 10px;
    overflow-y: auto;
    max-height: 250px;
    border: 1px solid #eee;
    border-radius: 5px;
}
/* One clickable nodule entry; the left border marks the selected one */
.nodule-item {
    background-color: #f5f5f5;
    border-radius: 5px;
    padding: 10px;
    margin-bottom: 10px;
    cursor: pointer;
    border-left: 4px solid #ccc;
    transition: all 0.2s ease;
}
.nodule-item:hover {
    background-color: #e9e9e9;
    transform: translateX(2px);
}
.nodule-item.selected {
    background-color: #e1f5fe;
    border-left-color: #2196F3;
}
/* Nodule detail section */
.nodule-detail-section {
    flex: 2;
}
#nodule-detail-area {
    background-color: #f5f5f5;
    border-radius: 5px;
    padding: 15px;
    height: 100%;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
/* Placeholder shown while no nodule is selected */
.no-selection {
    display: flex;
    align-items: center;
    justify-content: center;
    height: 100%;
    min-height: 150px;
    color: #999;
    font-style: italic;
}
/* Notice shown when no nodules were found */
.no-nodules {
    text-align: center;
    padding: 20px;
    color: #666;
    font-style: italic;
    background-color: #f9f9f9;
    border-radius: 5px;
}
/* Preview area */
.nodule-preview {
    margin-top: 20px;
}
.nodule-preview h4 {
    font-size: 14px;
    margin-bottom: 10px;
}
/* Preview views: three per row on wide screens (31% each) */
.preview-views {
    display: flex;
    flex-wrap: wrap;
    justify-content: space-between;
    margin-top: 15px;
}
.preview-view {
    width: 31%;
    margin-bottom: 15px;
    background-color: #f9f9f9;
    border-radius: 5px;
    overflow: hidden;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
.preview-view h5 {
    margin: 0;
    padding: 8px;
    font-size: 14px;
    background-color: #eaeaea;
    text-align: center;
}
/* Fixed-height box that centres the preview image */
.preview-img {
    height: 150px;
    display: flex;
    align-items: center;
    justify-content: center;
    position: relative;
    overflow: hidden;
}
.preview-img img {
    max-width: 100%;
    max-height: 100%;
    display: block;
    border: none;
}
/* Compact loading indicator (overrides the base .loading-spinner / .spinner rules by specificity) */
.loading-spinner.small {
    padding: 10px;
}
.loading-spinner.small .spinner {
    width: 25px;
    height: 25px;
    border-width: 3px;
}
.loading-spinner.small p {
    font-size: 12px;
    margin-top: 5px;
}
/* Compact error message */
.error-message.small {
    font-size: 12px;
    padding: 10px;
    margin: 5px;
}
/* Narrow screens: one preview per row */
@media screen and (max-width: 768px) {
    .preview-view {
        width: 100%;
        margin-bottom: 10px;
    }
}
/* Toast messages, stacked in the top-right corner */
.message-container {
    position: fixed;
    top: 20px;
    right: 20px;
    width: 300px;
    z-index: 1000;
}
.message {
    padding: 15px;
    margin-bottom: 10px;
    border-radius: 5px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
    animation: slideIn 0.3s ease;
}
/* Colour per message type */
.message.info {
    background-color: #3498db;
    color: white;
}
.message.success {
    background-color: #2ecc71;
    color: white;
}
.message.warning {
    background-color: #f39c12;
    color: white;
}
.message.error {
    background-color: #e74c3c;
    color: white;
}
/* Fade-out before removal; `forwards` keeps the final opacity 0 */
.message.fade-out {
    animation: fadeOut 0.5s ease forwards;
}
@keyframes slideIn {
    from {
        transform: translateX(100%);
        opacity: 0;
    }
    to {
        transform: translateX(0);
        opacity: 1;
    }
}
@keyframes fadeOut {
    from {
        opacity: 1;
    }
    to {
        opacity: 0;
    }
}
/* Responsive layout: stack sidebar above main content on narrow screens */
@media (max-width: 768px) {
    .container {
        flex-direction: column;
    }
    .sidebar {
        width: 100%;
        max-height: 300px;
    }
    /* Remaining viewport height below the (max 300px) sidebar */
    .main-content {
        height: calc(100vh - 300px);
    }
    /* NOTE(review): no other rule in this stylesheet uses .ct-viewer-container
       (the viewer wrapper is styled as .viewer-container). Confirm the class
       exists in index.html; otherwise this rule has no effect. */
    .ct-viewer-container {
        height: 40%;
    }
}
/* Nodule list item */
.nodule-header {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 8px;
}
.nodule-name {
    font-weight: 600;
    font-size: 14px;
}
/* Probability badge, coloured by risk level */
.nodule-probability {
    padding: 3px 6px;
    border-radius: 10px;
    font-size: 12px;
    font-weight: bold;
}
.nodule-probability.high {
    background-color: #ffebee;
    color: #d32f2f;
}
.nodule-probability.medium {
    background-color: #fff8e1;
    color: #ff8f00;
}
.nodule-probability.low {
    background-color: #e8f5e9;
    color: #388e3c;
}
.nodule-details {
    font-size: 12px;
    color: #666;
}
.nodule-detail {
    margin-bottom: 4px;
}
/* Detail rows: fixed-width label followed by a flexible value */
.detail-item {
    margin-bottom: 8px;
    display: flex;
}
/* Single .detail-label rule. It used to be declared twice; the later rule's
   color (#666) always won, so the earlier color (#777) was dead. Merged
   result: same computed style (width, color, margin-right). */
.detail-label {
    width: 90px;
    color: #666;
    margin-right: 5px;
}
.detail-value {
    flex: 1;
    font-weight: 500;
}
.detail-value.high {
    color: #d32f2f;
}
.detail-value.medium {
    color: #ff8f00;
}
.detail-value.low {
    color: #388e3c;
}
/* Loading state */
.loading-spinner {
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    padding: 30px;
}
/* Spinning ring: a light border with a coloured top edge, rotated by `spin` */
.spinner {
    width: 30px;
    height: 30px;
    border: 3px solid #f3f3f3;
    border-top: 3px solid #2196F3;
    border-radius: 50%;
    animation: spin 1s linear infinite;
    margin-bottom: 10px;
}
@keyframes spin {
    0% { transform: rotate(0deg); }
    100% { transform: rotate(360deg); }
}
/* Error message box */
.error-message {
    color: #d32f2f;
    padding: 15px;
    text-align: center;
    background-color: #ffebee;
    border-radius: 4px;
}
/* CT slice views */
.slice-views-container {
    display: flex;
    flex-wrap: wrap;
    justify-content: space-between;
    width: 100%;
    height: 100%;
    min-height: 450px; /* enough height to show the slices */
    padding: 10px;
    box-sizing: border-box;
}
/* One slice panel (header, image, controls); three per row by default */
.slice-view-container {
    width: calc(33.33% - 10px);
    height: 450px; /* fixed height */
    min-height: 300px;
    display: flex;
    flex-direction: column;
    background-color: #222;
    border-radius: 5px;
    overflow: hidden;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.2);
    margin-bottom: 15px;
}
.slice-view-header {
    background-color: #333;
    color: white;
    padding: 8px 12px;
    font-weight: bold;
    font-size: 14px;
    text-align: center;
}
.slice-view {
    flex: 1;
    display: flex;
    align-items: center;
    justify-content: center;
    overflow: hidden;
    background-color: #000;
    position: relative;
    min-height: 350px; /* enough height for the slice image area */
}
.slice-view img {
    max-width: 100%;
    max-height: 100%;
    object-fit: contain;
    display: block;
    border: 1px solid #444; /* border makes the image bounds easier to see */
}
.slice-view-controls {
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 8px;
    background-color: #333;
}
/* Square previous/next slice buttons */
.slice-btn {
    background-color: #444;
    color: white;
    border: none;
    border-radius: 3px;
    width: 30px;
    height: 30px;
    display: flex;
    align-items: center;
    justify-content: center;
    cursor: pointer;
    transition: background-color 0.2s;
}
.slice-btn:hover {
    background-color: #666;
}
.slice-btn:active {
    background-color: #555;
}
.slice-index {
    color: white;
    margin: 0 15px;
    font-size: 14px;
    min-width: 60px;
    text-align: center;
}
/* Responsive: two panels per row below 1200px, one per row below 768px */
@media (max-width: 1200px) {
    .slice-view-container {
        width: calc(50% - 10px);
    }
}
@media (max-width: 768px) {
    .slice-view-container {
        width: 100%;
    }
}
================================================
FILE: deploy/frontend/index.html
================================================
CT 肺结节检测系统
================================================
FILE: deploy/frontend/js/main.js
================================================
/**
 * CT lung nodule detection system - main JS file
 */
// Page-wide state shared by the functions in this file.
let currentSessionId = null,          // session id returned by /api/upload (mirrored in sessionStorage)
    currentSliceIndex = 0,            // z index of the slice currently shown
    maxSliceIndex = 0,                // highest valid z index (z_slices - 1)
    lungSegmentationLoaded = false,   // set once the lung slice info has been fetched
    progressInterval = null;          // id of the active progress-polling setInterval
// Initialise the page once the DOM has been parsed.
document.addEventListener('DOMContentLoaded', function() {
    console.log('CT肺结节检测系统初始化');
    // Event listeners for the viewer, detection button and nodule list.
    initializeEventListeners();
    // Initial step/UI state.
    initializeSteps();
    // File picker and upload button.
    setupFileUpload();
    // NOTE: uploadFile() stores the session id in sessionStorage, but restoring
    // a previous session on page load is not implemented.
});
// Put the step UI into its starting state; the initial state activates the upload step.
function initializeSteps() {
    const startState = 'initial';
    updateUIState(startState);
}
// Wire up the file picker, its validation and the upload button.
function setupFileUpload() {
    const fileInput = document.getElementById('ct-file');
    const selectFileBtn = document.getElementById('select-file-btn');
    const uploadBtn = document.getElementById('upload-btn');
    const fileInfo = document.getElementById('file-info');
    // Accepted file-name suffixes. getFileExtension() returns only the last
    // suffix ('scan.nii.gz' -> 'gz'), which rejected .nii.gz files, so the
    // names are matched with endsWith instead.
    const supportedSuffixes = ['.zip', '.mhd', '.dcm', '.dicom', '.nii', '.nii.gz'];

    // The styled button opens the hidden native file input.
    selectFileBtn.addEventListener('click', function() {
        fileInput.click();
    });

    // Validate the chosen file and enable/disable the upload button.
    fileInput.addEventListener('change', function() {
        if (this.files.length === 0) {
            fileInfo.textContent = '请选择CT文件';
            uploadBtn.disabled = true;
            return;
        }
        const file = this.files[0];
        const lowerName = file.name.toLowerCase();
        if (!supportedSuffixes.some(suffix => lowerName.endsWith(suffix))) {
            showMessage('请上传支持的格式: MHD, DICOM, NIfTI 或 ZIP文件', 'error');
            fileInfo.textContent = '请选择支持的CT文件格式';
            uploadBtn.disabled = true;
            return;
        }
        fileInfo.textContent = `已选择: ${file.name} (${formatFileSize(file.size)})`;
        uploadBtn.disabled = false;
    });

    // Upload the selected file.
    uploadBtn.addEventListener('click', function() {
        if (fileInput.files.length === 0) {
            showMessage('请先选择文件', 'error');
            return;
        }
        uploadFile(fileInput.files[0]);
    });
}
// Attach the page-level event listeners (reset view, start detection, nodule selection).
function initializeEventListeners() {
    // Reset view: jump back to the middle slice if a volume is loaded.
    document.getElementById('reset-view-btn').addEventListener('click', function() {
        if (!(maxSliceIndex > 0)) {
            showMessage('无法重置视图,未加载数据', 'warning');
            return;
        }
        currentSliceIndex = Math.floor(maxSliceIndex / 2);
        loadSlice(currentSliceIndex);
        showMessage('已重置视图至中间切片', 'info');
    });

    document.getElementById('start-detection-btn').addEventListener('click', startDetection);

    // One delegated handler for every nodule item in the list.
    document.getElementById('nodule-list').addEventListener('click', function(event) {
        const clickedItem = event.target.closest('.nodule-item');
        if (!clickedItem) {
            return;
        }
        const noduleId = clickedItem.dataset.id;
        for (const item of document.querySelectorAll('.nodule-item')) {
            item.classList.remove('selected');
        }
        clickedItem.classList.add('selected');
        console.log(`选中结节 ${noduleId}`);
        loadNoduleDetails(noduleId);
    });
}
// Return the text after the last '.' of `filename` (the whole name when it has no '.').
// Only the final suffix is returned, e.g. 'scan.nii.gz' -> 'gz'.
function getFileExtension(filename) {
    const lastDot = filename.lastIndexOf('.');
    return filename.slice(lastDot + 1);
}
// Format a byte count for display, e.g. 1536 -> '1.5 KB'.
function formatFileSize(bytes) {
    if (bytes === 0) return '0 Bytes';
    const k = 1024;
    const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB'];
    // Clamp the unit index: sub-byte values give a negative exponent and very
    // large values run past the last unit, both of which used to index
    // outside `sizes` and print "undefined".
    const exponent = Math.floor(Math.log(bytes) / Math.log(k));
    const i = Number.isFinite(exponent)
        ? Math.min(sizes.length - 1, Math.max(0, exponent))
        : 0;
    return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
}
// Upload `file` (plus an optional patient id) to /api/upload and update the UI.
function uploadFile(file) {
    console.log('开始上传文件:', file.name);

    // Swap the upload panel for the progress panel.
    document.getElementById('upload-area').style.display = 'none';
    document.getElementById('progress-area').style.display = 'block';
    updateProgress(10, '正在上传文件...');

    const formData = new FormData();
    formData.append('file', file);
    const patientId = document.getElementById('patient-id').value;
    if (patientId) {
        formData.append('patient_id', patientId);
        document.getElementById('patient-info').textContent = `患者ID: ${patientId}`;
    }

    // HTTP status check, then JSON body.
    const parseResponse = (response) => {
        console.log('服务器响应状态:', response.status);
        if (!response.ok) {
            throw new Error(`网络响应错误: ${response.status}`);
        }
        return response.json();
    };

    // Successful upload: remember the session and either auto-detect or wait for the user.
    const handleUploadResult = (data) => {
        console.log('上传响应:', data);
        if (!data.success) {
            throw new Error(data.error || '上传失败');
        }
        currentSessionId = data.session_id;
        console.log('保存会话ID:', currentSessionId);
        sessionStorage.setItem('currentSessionId', currentSessionId);
        updateUIState('uploaded');
        updateProgress(100, '文件上传成功');

        const typeInfo = data.file_type ? `检测到${data.file_type}格式CT数据` : '文件已上传';
        const progressMessage = document.getElementById('progress-message');
        const startButton = document.getElementById('start-detection-btn');
        if (data.auto_detect) {
            progressMessage.textContent = `${typeInfo},正在自动开始检测...`;
            startButton.disabled = true;
            startProgressPolling();
        } else {
            progressMessage.textContent = `${typeInfo},可以开始检测`;
            startButton.disabled = false;
        }
        showMessage('文件上传成功', 'success');
    };

    // Failure: report it and bring the upload panel back after 3 seconds.
    const handleUploadError = (error) => {
        console.error('文件上传错误:', error);
        updateUIState('initial');
        document.getElementById('progress-message').textContent = `上传失败: ${error.message}`;
        showMessage(`上传失败: ${error.message}`, 'error');
        setTimeout(() => {
            document.getElementById('progress-area').style.display = 'none';
            document.getElementById('upload-area').style.display = 'block';
        }, 3000);
    };

    fetch('/api/upload', { method: 'POST', body: formData })
        .then(parseResponse)
        .then(handleUploadResult)
        .catch(handleUploadError);
}
// Ask the backend to start detection for the current session, then poll its progress.
async function startDetection() {
    if (!currentSessionId) {
        showMessage('请先上传CT文件', 'error');
        return;
    }
    try {
        updateUIState('detecting');
        // Clear results from any previous run.
        resetResults();
        showMessage('正在开始检测...', 'info');

        const payload = JSON.stringify({ session_id: currentSessionId });
        const response = await fetch('/api/detect', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: payload
        });
        const result = await response.json();
        if (!result.success) {
            throw new Error(result.error || '启动检测失败');
        }

        showMessage('检测已启动', 'success');
        startProgressPolling();
    } catch (err) {
        console.error('启动检测时出错:', err);
        showMessage(`启动检测失败: ${err.message}`, 'error');
        updateUIState('uploaded');
    }
}
// Poll /api/progress/<session> once per second until detection completes or fails.
function startProgressPolling() {
    // Stop any previous polling loop.
    if (progressInterval) {
        clearInterval(progressInterval);
    }
    updateProgress(0, '正在初始化检测...');
    document.getElementById('progress-area').style.display = 'block';
    // The lung slices of the new run have not been loaded yet.
    lungSegmentationLoaded = false;

    // setInterval does not wait for the async callback. Without this flag a slow
    // response lets ticks overlap, so progress can be applied out of order and a
    // late tick can call fetchResults() again after polling has already stopped.
    let requestInFlight = false;

    const intervalId = setInterval(async function() {
        if (requestInFlight) {
            return;
        }
        if (!currentSessionId) {
            stopPolling();
            return;
        }
        requestInFlight = true;
        try {
            const response = await fetch(`/api/progress/${currentSessionId}`);
            const data = await response.json();
            // Ignore the answer if this loop was stopped or replaced meanwhile.
            if (progressInterval !== intervalId) {
                return;
            }
            if (!data.success) {
                throw new Error(data.error || '获取进度失败');
            }
            updateProgress(data.progress, data.message);
            // From 40% progress on, the lung slice data can be loaded.
            if (data.progress >= 40 && !lungSegmentationLoaded) {
                loadLungSegmentation();
            }
            if (data.status === 'completed') {
                stopPolling();
                fetchResults();
            } else if (data.status === 'error') {
                stopPolling();
                showMessage(`检测失败: ${data.error || '未知错误'}`, 'error');
                updateUIState('uploaded');
            }
        } catch (error) {
            if (progressInterval !== intervalId) {
                return;
            }
            console.error('获取进度时出错:', error);
            showMessage(`获取进度失败: ${error.message}`, 'error');
        } finally {
            requestInFlight = false;
        }
    }, 1000); // poll once per second
    progressInterval = intervalId;

    // Stop this loop. Only clear the global handle if it still refers to this
    // loop, so a stale tick cannot cancel a newer polling loop.
    function stopPolling() {
        clearInterval(intervalId);
        if (progressInterval === intervalId) {
            progressInterval = null;
        }
    }
}
// 更新进度条
// Reflect detection progress in the UI: bar width and percentage label (both required), and
// the status line when a non-empty message is given.
function updateProgress(progress, message) {
    const bar = document.getElementById('progress-bar');
    const label = document.getElementById('progress-text');
    const status = document.getElementById('progress-message');
    const percentText = `${progress}%`;

    if (bar && label) {
        bar.style.width = percentText;
        label.textContent = percentText;
    }
    if (status && message) {
        status.textContent = message;
    }
}
// 加载肺部切片数据
// Fetch slice metadata for the current session, then build the CT viewer centred on the middle
// z slice. Sets `lungSegmentationLoaded` once the viewer is ready.
function loadLungSegmentation() {
    console.log('加载肺部切片数据...');
    if (!currentSessionId) {
        console.error('无法加载肺部切片: 没有会话ID');
        return;
    }
    // The progress poller calls this every tick until loading succeeds. Skip while a request is
    // still outstanding, so overlapping responses do not re-initialise the viewer repeatedly.
    if (loadLungSegmentation.inFlight) {
        return;
    }
    loadLungSegmentation.inFlight = true;

    const sessionId = currentSessionId;
    const setStatus = text => {
        const progressMessage = document.getElementById('progress-message');
        if (progressMessage) {
            progressMessage.textContent = text;
        }
    };
    setStatus('加载肺部切片数据...');

    fetch(`/api/lung_slices_info/${sessionId}`)
        .then(response => {
            console.log('请求lung sliced info 返回了什么\t', response);
            if (!response.ok) {
                throw new Error(`服务器返回错误状态: ${response.status}`);
            }
            return response.json();
        })
        .then(data => {
            if (!data.success) {
                throw new Error(data.error || '获取切片信息失败');
            }
            // Ignore a response that arrives after the user switched to another session.
            if (sessionId !== currentSessionId) {
                return;
            }
            console.log('获取到切片信息:', data.slices_info);
            // A missing or invalid z_slices would make every later index computation NaN.
            const zSlices = Number(data.slices_info && data.slices_info.z_slices);
            if (!Number.isFinite(zSlices) || zSlices <= 0) {
                throw new Error('获取切片信息失败');
            }
            maxSliceIndex = zSlices - 1;
            // Start in the middle of the stack.
            currentSliceIndex = Math.floor(maxSliceIndex / 2);
            initCTViewer();
            loadSlice(currentSliceIndex);
            setStatus('肺部切片数据加载完成');
            lungSegmentationLoaded = true;
        })
        .catch(error => {
            console.error('获取肺部切片信息失败:', error);
            setStatus(`加载失败: ${error.message}`);
            showMessage(`肺部切片数据加载失败: ${error.message}`, 'error');
        })
        .finally(() => {
            loadLungSegmentation.inFlight = false;
        });
}
// 初始化CT查看器
// (Re)build the single z-axis slice view inside #ct-viewer and wire up wheel and button navigation.
function initCTViewer() {
    console.log('初始化CT查看器...');
    const viewerContainer = document.getElementById('ct-viewer');
    if (!viewerContainer) {
        console.error('找不到CT查看器容器元素');
        return;
    }
    // Replace any previous viewer content.
    viewerContainer.innerHTML = '';

    // Slice view showing z-axis slices only.
    const sliceView = document.createElement('div');
    sliceView.className = 'slice-view';
    sliceView.id = 'main-slice-view';
    sliceView.innerHTML = `
0 / 0
`;
    viewerContainer.appendChild(sliceView);

    // addEventListener ignores a repeat registration of the same function reference, so
    // re-initialising does not stack wheel handlers. passive:false keeps preventDefault effective.
    viewerContainer.addEventListener('wheel', handleSliceScroll, { passive: false });

    // Assign handlers (instead of adding listeners) so repeated initialisation replaces them
    // rather than stacking an extra changeSlice call per click; skip buttons that are absent.
    const prevButton = document.getElementById('prev-btn');
    const nextButton = document.getElementById('next-btn');
    if (prevButton) {
        prevButton.onclick = () => changeSlice(-1);
    }
    if (nextButton) {
        nextButton.onclick = () => changeSlice(1);
    }
}
// 处理滚轮事件
// Wheel handler: scroll down moves to the next slice, scroll up to the previous one.
// A zero vertical delta results in changeSlice(0), which is a no-op.
function handleSliceScroll(event) {
    // The wheel drives slice navigation, so stop it from scrolling the page.
    event.preventDefault();
    changeSlice(Math.sign(event.deltaY));
}
// 切换切片
// Move the current slice by `delta`, clamped to [0, maxSliceIndex]; reloads only on an actual change.
function changeSlice(delta) {
    // Nothing to navigate until slice info is loaded. Without this guard an unset index makes the
    // arithmetic below NaN, and a NaN slice would be requested from the server.
    if (!Number.isFinite(maxSliceIndex) || maxSliceIndex < 0 || !Number.isFinite(currentSliceIndex)) {
        return;
    }
    const newIndex = Math.max(0, Math.min(maxSliceIndex, currentSliceIndex + delta));
    if (newIndex === currentSliceIndex) {
        return;
    }
    currentSliceIndex = newIndex;
    loadSlice(currentSliceIndex);
}
// 加载切片
// Display z slice `sliceIndex` in #slice-img, update the "i / max" label, and re-render the
// nodule markers once the new image has loaded.
function loadSlice(sliceIndex) {
    console.log(`加载Z轴切片,索引: ${sliceIndex}`);
    const imgElement = document.getElementById('slice-img');
    if (!imgElement) {
        console.error('找不到切片图像元素');
        return;
    }

    const indexElement = document.getElementById('slice-index');
    if (indexElement) {
        indexElement.textContent = `${sliceIndex} / ${maxSliceIndex}`;
    }

    const url = `/api/lung_slice/${currentSessionId}/z/${sliceIndex}`;
    // Timestamp query parameter defeats browser caching.
    const finalUrl = `${url}?t=${Date.now()}`;

    imgElement.onload = function() {
        console.log('切片图像加载成功');
        // Render markers only after the new image has loaded. renderNoduleMarkers bails out while
        // `complete` is false, which previously left the previous slice's markers on screen.
        if (window.renderNoduleMarkers) {
            window.renderNoduleMarkers(sliceIndex);
        }
    };
    imgElement.onerror = function() {
        console.error('切片图像加载失败');
        // Detach this handler and drop the src attribute. Setting src = '' can itself fire
        // `error`, which would re-enter this handler in a loop.
        imgElement.onerror = null;
        imgElement.removeAttribute('src');
        showMessage('切片图像加载失败', 'error');
    };

    imgElement.src = finalUrl;
}
// 获取检测结果
// Fetch detection results for the current session and refresh the nodule list and markers.
// Also starts loading the lung slices if they are not loaded yet.
function fetchResults() {
    console.log('获取检测结果...');
    if (!currentSessionId) {
        console.error('无法获取结果: 没有会话ID');
        return;
    }
    const setStatus = text => {
        const progressMessage = document.getElementById('progress-message');
        if (progressMessage) {
            progressMessage.textContent = text;
        }
    };
    setStatus('加载检测结果...');

    if (!lungSegmentationLoaded) {
        loadLungSegmentation();
    }

    fetch(`/api/results/${currentSessionId}`)
        .then(response => {
            if (!response.ok) {
                throw new Error(`服务器返回错误状态: ${response.status}`);
            }
            return response.json();
        })
        .then(data => {
            console.log('检测结果:', data);
            const nodules = Array.isArray(data.nodules) ? data.nodules : [];
            updateUIState('detected');
            setStatus(nodules.length > 0 ? `检测到 ${nodules.length} 个结节` : '未检测到结节');
            // Always replace the global (it is read by loadNoduleDetails and the marker renderer),
            // so a run with no nodules does not keep showing the previous run's nodules.
            window.nodules = nodules;
            updateNoduleList(nodules);
            if (nodules.length > 0) {
                createNoduleMarkers();
            } else {
                document.querySelectorAll('.nodule-marker').forEach(marker => marker.remove());
            }
        })
        .catch(error => {
            console.error('获取结果出错:', error);
            setStatus(`获取结果失败: ${error.message}`);
            showMessage(`获取结果出错: ${error.message}`, 'error');
        });
}
// 更新结节列表
// Rebuild the nodule list panel (#nodule-list) and the count badge (#nodules-count) from
// `nodules`, then select the first entry and load its details.
// @param {Array<Object>|null|undefined} nodules - detection results; the fields read here are
//   id, diameter_mm, probability and voxel_coords.
function updateNoduleList(nodules) {
    console.log('更新结节列表,数量:', nodules ? nodules.length : 0);
    const noduleListContainer = document.getElementById('nodule-list');
    if (!noduleListContainer) {
        console.error('找不到结节列表容器');
        return;
    }
    // Clear the existing list.
    noduleListContainer.innerHTML = '';
    // Update the nodule count badge.
    const noduleCountElement = document.getElementById('nodules-count');
    if (noduleCountElement) {
        noduleCountElement.textContent = nodules ? nodules.length : '0';
    }
    // No nodules: show a placeholder and stop.
    // NOTE(review): the literal below spans a line break, which is a syntax error for a
    // single-quoted JS string. Markup appears to have been stripped from this copy; restore it.
    if (!nodules || nodules.length === 0) {
        noduleListContainer.innerHTML = '未检测到结节
';
        return;
    }
    // One list item per nodule; a failure on one nodule is logged and does not stop the rest.
    nodules.forEach((nodule, index) => {
        try {
            const noduleItem = document.createElement('div');
            noduleItem.className = 'nodule-item';
            // Uses the array position when the nodule has no truthy id (an id of 0 also falls back).
            noduleItem.dataset.id = nodule.id || index;
            // Display defaults when the field is missing or falsy: 10 mm diameter, 50% probability.
            const diameter = (nodule.diameter_mm || 10).toFixed(1);
            // NOTE(review): `probability` is computed but not interpolated into the markup below.
            const probability = ((nodule.probability || 0.5) * 100).toFixed(0);
            // Voxel coordinates rounded to integers, or a placeholder when absent.
            let coordsStr = '[未知]';
            if (nodule.voxel_coords && Array.isArray(nodule.voxel_coords)) {
                coordsStr = `[${nodule.voxel_coords.map(v => Math.round(v)).join(', ')}]`;
            }
            // Nodule name and details.
            noduleItem.innerHTML = `
直径:
${diameter} mm
位置:
${coordsStr}
`;
            noduleListContainer.appendChild(noduleItem);
        } catch (error) {
            console.error(`处理结节 #${index+1} 时出错:`, error);
        }
    });
    // Select the first nodule by default and show its details.
    if (nodules.length > 0) {
        const firstNoduleItem = noduleListContainer.querySelector('.nodule-item');
        if (firstNoduleItem) {
            firstNoduleItem.classList.add('selected');
            const noduleId = firstNoduleItem.dataset.id;
            loadNoduleDetails(noduleId);
        }
    }
}
// 创建结节标记
// Install window.renderNoduleMarkers (also used by loadSlice) and render markers for the
// current slice. Each marker is a numbered, probability-coloured box over a nodule within one
// slice of the shown z index; clicking it selects the nodule.
function createNoduleMarkers() {
    if (!window.nodules || !window.nodules.length) {
        console.log('没有结节数据可用于创建标记');
        return;
    }
    console.log(`为 ${window.nodules.length} 个结节创建标记`);

    window.renderNoduleMarkers = function(currentZIndex) {
        // Always drop the previous slice's markers first, even if the new image is not ready yet,
        // so stale markers never sit on top of a different slice.
        document.querySelectorAll('.nodule-marker').forEach(marker => marker.remove());

        const imageContainer = document.querySelector('.slice-image-container');
        const sliceImg = document.getElementById('slice-img');
        if (!imageContainer || !sliceImg || !sliceImg.complete) {
            console.log('图像容器或图像未准备好,无法渲染结节标记');
            return;
        }

        // Displayed size vs. intrinsic (source PNG) size.
        const imgWidth = sliceImg.width;
        const imgHeight = sliceImg.height;
        const naturalWidth = sliceImg.naturalWidth;
        const naturalHeight = sliceImg.naturalHeight;
        if (imgWidth <= 0 || imgHeight <= 0 || !naturalWidth || !naturalHeight) {
            console.log('图像尺寸无效,无法渲染标记');
            return;
        }
        // Scale voxel x/y by displayed-pixels per source pixel. The old code divided by
        // maxSliceIndex, which is the z slice count and unrelated to the in-plane size.
        // NOTE(review): assumes the slice PNG has one pixel per voxel (column = x, row = y);
        // confirm against the backend's slice export.
        const scaleX = imgWidth / naturalWidth;
        const scaleY = imgHeight / naturalHeight;
        // Show nodules on the current z slice or within this many slices of it.
        const zTolerance = 1;

        window.nodules.forEach((nodule, index) => {
            try {
                if (!nodule.voxel_coords || !Array.isArray(nodule.voxel_coords) || nodule.voxel_coords.length < 3) {
                    console.warn(`结节 #${index+1} 没有有效坐标`);
                    return;
                }
                const [x, y, z] = nodule.voxel_coords.map(Math.round);
                if (Math.abs(z - currentZIndex) > zTolerance) {
                    return;
                }

                const marker = document.createElement('div');
                marker.className = 'nodule-marker';
                marker.dataset.id = nodule.id || index;

                // Marker size grows with diameter, minimum 16px.
                const size = Math.max(16, (nodule.diameter_mm || 5) * 2);
                marker.style.width = `${size}px`;
                marker.style.height = `${size}px`;
                marker.style.left = `${x * scaleX}px`;
                marker.style.top = `${y * scaleY}px`;

                // Colour class by probability (missing probability counts as 50%).
                const probability = (nodule.probability || 0.5) * 100;
                if (probability > 70) {
                    marker.classList.add('high-prob');
                } else if (probability > 30) {
                    marker.classList.add('medium-prob');
                } else {
                    marker.classList.add('low-prob');
                }

                marker.innerHTML = `${index + 1}`;
                marker.addEventListener('click', function() {
                    highlightNoduleInList(this.dataset.id);
                    loadNoduleDetails(this.dataset.id);
                });
                imageContainer.appendChild(marker);
            } catch (error) {
                console.error(`为结节 #${index+1} 创建标记时出错:`, error);
            }
        });
    };

    // Render markers for the slice currently shown.
    window.renderNoduleMarkers(currentSliceIndex);
}
// 高亮显示结节列表中的特定结节
// Mark the list entry whose data-id equals `noduleId` as selected (clearing any other
// selection) and scroll it into view.
function highlightNoduleInList(noduleId) {
    document.querySelectorAll('.nodule-item.selected').forEach(item => {
        item.classList.remove('selected');
    });
    // Compare dataset values directly instead of interpolating the id into a CSS selector,
    // which throws a SyntaxError for ids containing quotes or backslashes.
    const targetId = String(noduleId);
    const noduleItem = Array.from(document.querySelectorAll('.nodule-item'))
        .find(item => item.dataset.id === targetId);
    if (noduleItem) {
        noduleItem.classList.add('selected');
        noduleItem.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
    }
}
// 加载结节详情
// Show the details of one nodule in #nodule-detail-area and jump the viewer to its z slice.
// `noduleId` is matched against nodule.id first, then treated as an array position (list items
// and markers fall back to the position when a nodule has no id).
function loadNoduleDetails(noduleId) {
    console.log(`加载结节详情,ID: ${noduleId}`);
    if (!window.nodules) {
        console.error('没有结节数据可用');
        return;
    }
    // Nodule fields and error messages come from the server; escape them before they reach innerHTML.
    const escapeHtml = value => String(value)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;')
        .replace(/'/g, '&#39;');

    const targetId = parseInt(noduleId, 10);
    console.log('window nodule 有哪些数据\n', window.nodules);

    let position = window.nodules.findIndex(candidate => parseInt(candidate.id, 10) === targetId);
    if (position !== -1) {
        console.log(`找到匹配的结节,ID: ${targetId}`);
    } else {
        console.warn(`未找到ID为${targetId}的结节,尝试使用索引方式获取`);
        position = targetId;
    }
    const nodule = window.nodules[position];
    if (!nodule) {
        console.error(`找不到ID为 ${noduleId} 的结节`);
        return;
    }

    const detailsContainer = document.getElementById('nodule-detail-area');
    if (!detailsContainer) {
        console.error('找不到结节详情容器');
        return;
    }
    detailsContainer.innerHTML = '';
    try {
        // Display defaults when missing or falsy: 10 mm diameter, 50% probability, origin coordinates.
        const diameter = (nodule.diameter_mm || 10).toFixed(1);
        const probability = ((nodule.probability || 0.5) * 100).toFixed(0);
        const coords = nodule.voxel_coords || [0, 0, 0];
        const [x, y, z] = coords.map(v => Math.round(v));
        const zInRange = typeof z === 'number' && z >= 0 && z <= maxSliceIndex;

        // Jump to the nodule's slice when its z coordinate is valid.
        if (zInRange) {
            currentSliceIndex = z;
            loadSlice(z);
        }

        // Number by array position, matching the marker labels (index + 1). The parsed id can
        // differ from the position when ids are not 0-based.
        detailsContainer.innerHTML = `
结节 #${position + 1} 详情
直径:
${diameter} mm
恶性概率:
${probability}%
位置坐标:
[${x}, ${y}, ${z}]
类型:
${escapeHtml(nodule.type || '未分类')}
`;

        // The buttons live in the details markup; tolerate their absence instead of throwing.
        const gotoButton = document.getElementById('goto-nodule-btn');
        if (gotoButton) {
            gotoButton.addEventListener('click', function() {
                if (!zInRange) {
                    return;
                }
                currentSliceIndex = z;
                loadSlice(z);
                // Briefly highlight this nodule's marker once the slice has had time to render.
                setTimeout(() => {
                    const marker = Array.from(document.querySelectorAll('.nodule-marker'))
                        .find(element => element.dataset.id === String(noduleId));
                    if (marker) {
                        marker.classList.add('highlight');
                        setTimeout(() => {
                            marker.classList.remove('highlight');
                        }, 2000);
                    }
                }, 100);
            });
        }
        const reportButton = document.getElementById('nodule-report-btn');
        if (reportButton) {
            reportButton.addEventListener('click', function() {
                showMessage('结节报告功能尚未实现', 'info');
            });
        }
    } catch (error) {
        console.error('加载结节详情出错:', error);
        detailsContainer.innerHTML = `加载结节详情失败: ${escapeHtml(error.message)}
`;
    }
}
// 重置视图
// Return the viewer to the middle z slice; does nothing until slices are loaded and more than one exists.
function resetView() {
    console.log('重置视图');
    const nothingToReset = !lungSegmentationLoaded || maxSliceIndex <= 0;
    if (nothingToReset) {
        console.log('没有数据可重置');
        return;
    }
    const middleSlice = Math.floor(maxSliceIndex / 2);
    currentSliceIndex = middleSlice;
    loadSlice(middleSlice);
    showMessage('视图已重置', 'info');
}
// 显示消息
// Show a transient toast in #message-container (created on first use). It fades in after 10 ms
// and is removed after about 3.3 s. `type` becomes an extra CSS class.
function showMessage(message, type = 'info') {
    console.log(`显示消息: ${message} (${type})`);
    let msgContainer = document.getElementById('message-container');
    if (!msgContainer) {
        msgContainer = document.createElement('div');
        msgContainer.id = 'message-container';
        document.body.appendChild(msgContainer);
    }

    const msgElement = document.createElement('div');
    msgElement.className = `message ${type}`;
    // Messages often embed server-provided error text; escape it so it cannot inject markup.
    const safeMessage = String(message)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;')
        .replace(/'/g, '&#39;');
    msgElement.innerHTML = `
${safeMessage}
`;
    msgContainer.appendChild(msgElement);

    // Add 'show' on the next tick so the CSS transition can run.
    setTimeout(() => {
        msgElement.classList.add('show');
    }, 10);
    // Hide after 3 s, then remove once the 300 ms fade-out has finished.
    setTimeout(() => {
        msgElement.classList.remove('show');
        setTimeout(() => {
            msgElement.remove();
        }, 300);
    }, 3000);
}
// 更新UI状态
// Switch the page to `state`: set exactly one `state-<name>` class on <body> and set the step
// indicators (#step-upload, #step-detect, #step-visualize) to match.
function updateUIState(state) {
    // Drop every previous state-* class. The old fixed list missed 'state-uploaded', which
    // callers set after a failed detection, so it accumulated on <body>.
    Array.from(document.body.classList)
        .filter(cls => cls.startsWith('state-'))
        .forEach(cls => document.body.classList.remove(cls));
    document.body.classList.add(`state-${state}`);

    // Full class set per step and state. Setting the whole set, rather than only adding classes,
    // clears stale 'completed'/'active' when going back (e.g. to 'initial' after a failed upload).
    const stepClasses = {
        initial:   { upload: ['active'],              detect: [],                      visualize: [] },
        uploading: { upload: ['active'],              detect: [],                      visualize: [] },
        uploaded:  { upload: ['active', 'completed'], detect: ['active'],              visualize: [] },
        detecting: { upload: ['active', 'completed'], detect: ['active'],              visualize: [] },
        detected:  { upload: ['active', 'completed'], detect: ['active', 'completed'], visualize: ['active'] }
    }[state];
    if (!stepClasses) {
        // Unknown state: leave the step indicators untouched.
        return;
    }
    Object.entries(stepClasses).forEach(([step, classes]) => {
        const element = document.getElementById(`step-${step}`);
        if (!element) {
            return;
        }
        element.classList.toggle('active', classes.includes('active'));
        element.classList.toggle('completed', classes.includes('completed'));
    });
}
================================================
FILE: deploy/run.py
================================================
import os
import sys
import subprocess
import platform
import webbrowser
import time
def get_python_command():
    """Return the interpreter command used to launch the backend.

    Prefers ``sys.executable``: the interpreter running this script, whose packages
    ``check_dependencies`` just verified. A bare "python"/"python3" may resolve to a
    different interpreter (or a different virtualenv) without those packages. Falls back
    to the platform's conventional command only when ``sys.executable`` is empty.
    """
    if sys.executable:
        return sys.executable
    return "python" if platform.system() == "Windows" else "python3"
def check_dependencies(modules=("flask", "numpy", "tensorflow", "SimpleITK")):
    """Check that the backend's required modules can be imported.

    Args:
        modules: names of the modules to import. Defaults to the backend's core dependencies.

    Returns:
        True when every module imports. Otherwise prints all missing modules (not just the
        first one) plus an install hint, and returns False.
    """
    import importlib

    missing = []
    for name in modules:
        try:
            importlib.import_module(name)
        except ImportError as e:
            missing.append(str(e))

    if missing:
        print(f"✗ 缺少依赖: {'; '.join(missing)}")
        print("请安装所需依赖: pip install flask flask-cors numpy tensorflow SimpleITK")
        return False
    print("✓ 所有必要的依赖已安装")
    return True
def create_directories():
    """Ensure the backend's upload and model folders exist (relative to the cwd), then report it."""
    for relative_dir in ("backend/uploads", "backend/models"):
        os.makedirs(relative_dir, exist_ok=True)
    print("✓ 目录创建完成")
def run_backend_server(startup_wait=2, url="http://localhost:5000"):
    """Launch ``backend/app.py`` as a child process, open the UI, and block until it exits.

    Args:
        startup_wait: seconds to wait for the server before opening the browser.
        url: address opened in the browser and shown to the user.

    Ctrl+C (during startup or while running) terminates the child. If it does not exit
    within 10 seconds it is killed. If the backend has already exited when the wait
    ends, its exit code is reported and no browser is opened.
    """
    python_cmd = get_python_command()
    cmd = [python_cmd, "backend/app.py"]

    print("\n启动后端服务器...")
    process = subprocess.Popen(cmd)
    try:
        time.sleep(startup_wait)
        # The backend may have died already (import error, port in use, ...); don't open a
        # browser on a dead URL or claim the server is running.
        if process.poll() is not None:
            print(f"✗ 后端服务器启动失败,退出码: {process.returncode}")
            return

        print("正在打开浏览器...")
        webbrowser.open(url)
        print("\n服务器已启动!\n")
        print(f"在浏览器中访问: {url}")
        print("按 Ctrl+C 停止服务器")
        process.wait()
    except KeyboardInterrupt:
        print("\n正在停止服务器...")
        process.terminate()
        try:
            process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
        print("服务器已停止")
def main():
    """Launcher entry point: verify dependencies, prepare folders, then run the backend."""
    # Relative paths like "backend/app.py" are resolved against this script's directory.
    here = os.path.dirname(os.path.abspath(__file__))
    os.chdir(here)

    print("CT图像分析系统启动工具")
    print("=" * 30)

    if check_dependencies():
        create_directories()
        run_backend_server()


if __name__ == "__main__":
    main()
================================================
FILE: inference/__init__.py
================================================
================================================
FILE: inference/c3d_classify_result-1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139.csv
================================================
patient_id,nodule_id,voxel_x,voxel_y,voxel_z,world_x,world_y,world_z,diameter_mm,prob
1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139,1,253,256,81,58.0,76.0,-269.5,16.0,0.9999972581863403
1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139,9,192,180,95,-3.0,0.0,-255.5,16.0,0.9858711361885071
================================================
FILE: inference/classifier.py
================================================
import numpy,pandas
import os
from util import progress_watch
from detector import extract_dicom_images_patient,get_papaya_coords,prepare_image_for_net3D,predict_nodule_type
from keras.models import load_model, model_from_json
from keras.optimizers import SGD
from util.image_util import load_patient_images, rescale_patient_images
from util.ml.metrics import get_3d_pixel_l2_distance
from util.progress_watch import Stopwatch
# Stride (voxels) between neighbouring cubes scanned by locate_malignancy.
PREDICT_STEP = 12
# Edge length (voxels) of each cube fed to the localisation model in locate_malignancy.
CUBE_SIZE = 32
# Minimum model score for a cube to be kept as a candidate (locate_malignancy, multipule_test).
P_TH = 0.85
def scan(dicom_path, only_patient_id, workspace):
    """Detect nodule candidates for one patient and classify each one's malignancy.

    Extracts the DICOM series into ``workspace``, locates candidates with
    ``multipule_test``, classifies each candidate with the C3D malignancy model, and
    converts the results to viewer coordinates with ``get_papaya_coords``.

    Returns:
        (boxes, centers): one entry per candidate, in descending ``nodule_chance``
        order. Both lists are empty when no candidate survives.
    """
    target_dir = workspace
    boxes = []
    centers = []
    CONTINUE_JOB = True
    sw = Stopwatch.start_new()
    pixel_spacing, dicom_size, png_size, invert_order = extract_dicom_images_patient(dicom_path, target_dir)
    print("png_size: ", png_size)

    final_nodules_df = multipule_test(workspace, only_patient_id, CONTINUE_JOB)
    final_nodules_df_sort = final_nodules_df.sort_values(['nodule_chance'], ascending=False)
    print("predict from maligancy...")
    print("*" * 20)
    print(final_nodules_df_sort)
    print("*" * 20)
    if len(final_nodules_df_sort) == 0:
        return boxes, centers

    ggn_npy = predict_nodule_type(final_nodules_df_sort, png_size, workspace)
    with open('./models/c3d_malignancy_regreession.json', 'r') as json_file:
        loaded_model_json = json_file.read()
    loaded_model = model_from_json(loaded_model_json)
    model_weight = './models/c3d_malignancy_regreession_04_0.8719.hd5'
    loaded_model.load_weights(model_weight)
    print("Loaded model_0 from disk")
    sgd = SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True)
    loaded_model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    model_result = loaded_model.predict(ggn_npy, batch_size=20, verbose=1)

    # One [label, score] per prediction. The list is indexed by candidate row below, so it
    # must stay aligned. Previously entries with score <= 0.5 were skipped, which shifted
    # later labels onto the wrong nodules or raised IndexError.
    ggn_class_list = []
    for ii in model_result:
        print(ii)
        ii_index = numpy.argmax(ii)
        score = round(ii[ii_index], 3)
        print("result from malignancy..")
        print(generate_ggn_class(ii_index), score)
        ggn_class_list.append([generate_ggn_class(ii_index), score])

    for i, (_, row) in enumerate(final_nodules_df_sort.iterrows()):
        print(ggn_class_list[i])
        coord_z = row["coord_z"]
        coord_y = row["coord_y"]
        coord_x = row["coord_x"]
        print("index-x-y-z-p", coord_x, coord_y, coord_z, row["nodule_chance"])
        box, center = get_papaya_coords(coord_x, coord_y, coord_z, row["nodule_chance"], pixel_spacing, dicom_size,
                                        png_size, invert_order, ggn_class_list[i])
        boxes.append(box)
        centers.append(center)
    print("ALL Complete in : ", sw.get_elapsed_seconds(), " seconds")
    return boxes, centers
def generate_ggn_class(ii_index):
    """Translate the malignancy model's argmax index into its label.

    0 -> 'non_malignancy', 1 -> 'malignancy'; any other value yields None.
    """
    for code, label in enumerate(('non_malignancy', 'malignancy')):
        if ii_index == code:
            return label
    return None
def multipule_test(workspace, only_patient_id, CONTINUE_JOB):
    """Run both localisation checkpoints and merge their candidates.

    Each checkpoint's candidates are kept only if ``nodule_chance > P_TH``. The union
    is then de-duplicated per slice by ``reduce_predicts_same_slice``.
    """
    per_model_frames = []
    for model_version in ("model_loc.hd5", "model_loc_val_0.96.hd5"):
        print("gpu begin:")
        candidates = locate_malignancy(workspace, "models/" + model_version, CONTINUE_JOB,
                                       only_patient_id=only_patient_id, magnification=1, flip=False,
                                       train_data=True, holdout_no=None, ext_name="luna16_fs")
        per_model_frames.append(candidates[candidates["nodule_chance"] > P_TH])
    return reduce_predicts_same_slice(pandas.concat(per_model_frames))
def locate_malignancy(png_path, model_weight, CONTINUE_JOB, only_patient_id,
                      magnification=1, flip=False, train_data=True, holdout_no=None,
                      ext_name="luna16_fs"):
    """Slide a cube over the patient's PNG volume and score each lung cube with the locator model.

    Args:
        png_path: workspace directory. Image slices are read from ``<png_path>/png/*_i.png``
            and lung masks from ``*_m.png``.
        model_weight: weights file loaded into the architecture from
            ``../service/workdir/model_loc.json``.
        CONTINUE_JOB, train_data, holdout_no, ext_name: accepted for call compatibility; unused.
        only_patient_id: patient id used in log output and mask filtering.
        magnification: rescale factor applied to image and mask before scanning.
        flip: mirror each cube along its last axis before prediction.

    Returns:
        DataFrame with columns anno_index, coord_x, coord_y, coord_z and nodule_chance. The
        coordinates are cube centres as fractions of the volume size. It holds cubes scoring
        above ``P_TH`` that survive ``filter_patient_nodules_predictions``.
    """
    patient_id = only_patient_id
    # `helpers` is not defined in this module (NameError); Stopwatch is imported from util.progress_watch.
    sw = Stopwatch.start_new()
    with open('../service/workdir/model_loc.json', 'r') as json_file:
        loaded_model_json = json_file.read()
    model = model_from_json(loaded_model_json)
    model.load_weights(model_weight)

    patient_img = load_patient_images(png_path + "/png/", "*_i.png", [])
    if magnification != 1:
        patient_img = rescale_patient_images(patient_img, (1, 1, 1), magnification)
    patient_mask = load_patient_images(png_path + "/png/", "*_m.png", [])
    if magnification != 1:
        patient_mask = rescale_patient_images(patient_mask, (1, 1, 1), magnification, is_mask_image=True)

    step = PREDICT_STEP
    CROP_SIZE = CUBE_SIZE
    # Cube positions per axis: offsets 0, step, 2*step, ... while the cube fits strictly inside.
    predict_volume_shape_list = [0, 0, 0]
    for dim in range(3):
        dim_indent = 0
        while dim_indent + CROP_SIZE < patient_img.shape[dim]:
            predict_volume_shape_list[dim] += 1
            dim_indent += step
    predict_volume_shape = tuple(predict_volume_shape_list)
    predict_volume = numpy.zeros(shape=predict_volume_shape, dtype=float)
    print("Predict volume shape: ", predict_volume.shape)

    done_count = 0
    skipped_count = 0
    batch_size = 32
    batch_list = []
    batch_list_coords = []
    patient_predictions_csv = []
    annotation_index = 0

    def score_batch():
        """Predict the queued cubes, record those scoring above P_TH, and empty the queue."""
        nonlocal annotation_index
        batch_data = numpy.vstack(batch_list)
        p = model.predict(batch_data, batch_size=batch_size)
        # p[0] is read as the nodule-probability output, one row per cube (multi-output model).
        for i in range(len(p[0])):
            p_z, p_y, p_x = batch_list_coords[i]
            malignancy_chance = p[0][i][0]
            predict_volume[p_z, p_y, p_x] = malignancy_chance
            if malignancy_chance > P_TH:
                # Cube centre in voxels, then as a fraction of each axis length.
                p_z_perc = round((p_z * step + CROP_SIZE / 2) / patient_img.shape[0], 4)
                p_y_perc = round((p_y * step + CROP_SIZE / 2) / patient_img.shape[1], 4)
                p_x_perc = round((p_x * step + CROP_SIZE / 2) / patient_img.shape[2], 4)
                nodule_chance = round(malignancy_chance, 4)
                patient_predictions_csv.append([annotation_index, p_x_perc, p_y_perc, p_z_perc,
                                                nodule_chance])
                annotation_index += 1
        batch_list.clear()
        batch_list_coords.clear()

    for z in range(predict_volume_shape[0]):
        for y in range(predict_volume_shape[1]):
            for x in range(predict_volume_shape[2]):
                cube_img = patient_img[z * step:z * step + CROP_SIZE, y * step:y * step + CROP_SIZE,
                                       x * step:x * step + CROP_SIZE]
                cube_mask = patient_mask[z * step:z * step + CROP_SIZE, y * step:y * step + CROP_SIZE,
                                         x * step:x * step + CROP_SIZE]
                # Skip cubes with too little lung mask inside.
                if cube_mask.sum() < 2000:
                    skipped_count += 1
                else:
                    if flip:
                        cube_img = cube_img[:, :, ::-1]
                    batch_list.append(prepare_image_for_net3D(cube_img))
                    batch_list_coords.append((z, y, x))
                    if len(batch_list) == batch_size:
                        score_batch()
                done_count += 1
                if done_count % 10000 == 0:
                    print("Scan: ", done_count, " skipped:", skipped_count)
    # The loop only predicts full batches; without this the trailing cubes were never scored.
    if batch_list:
        score_batch()

    df = pandas.DataFrame(patient_predictions_csv,
                          columns=["anno_index", "coord_x", "coord_y", "coord_z", "nodule_chance"])
    # Mutates df in place.
    filter_patient_nodules_predictions(df, patient_id, CROP_SIZE * magnification, png_path)
    print(predict_volume.mean())
    print("GPU costs : ", sw.get_elapsed_seconds(), " seconds")
    return df
def filter_patient_nodules_predictions(df_nodule_predictions: pandas.DataFrame, patient_id, view_size, png_path):
    """Remove implausible nodule candidates from ``df_nodule_predictions`` in place.

    A row is dropped when any of these holds:
    - the lung mask sums to <= 255 in a ``view_size`` x ``view_size`` window around its
      centre, on its own slice and on the slices directly above and below;
    - its centre slice index is below 30;
    - coord_z > 0.75 or < 0.25, together with coord_y > 0.85;
    - its centre slice index is below 50, together with coord_y < 0.30.

    Args:
        df_nodule_predictions: candidates with fractional coord_x/coord_y/coord_z columns.
        patient_id: used only in log messages.
        view_size: edge length (pixels) of the mask window.
        png_path: workspace whose ``png/*_m.png`` slices form the lung mask.

    Returns:
        The same DataFrame object, after the drop.
    """
    patient_mask = load_patient_images(png_path + "/png/", "*_m.png")
    half_view = view_size / 2
    window = int(view_size)
    delete_indices = []
    for index, row in df_nodule_predictions.iterrows():
        z_perc = row["coord_z"]
        y_perc = row["coord_y"]
        center_x = int(round(row["coord_x"] * patient_mask.shape[2]))
        center_y = int(round(y_perc * patient_mask.shape[1]))
        center_z = int(round(z_perc * patient_mask.shape[0]))
        start_x = int(center_x - half_view)
        start_y = int(center_y - half_view)

        # A mask sum above 255 means more than one mask pixel in the window.
        nodule_in_mask = any(
            patient_mask[center_z + z_offset][start_y:start_y + window, start_x:start_x + window].sum() > 255
            for z_offset in (-1, 0, 1)
        )
        if not nodule_in_mask:
            print("Nodule not in mask: ", (center_x, center_y, center_z))
            delete_indices.append(index)
            continue

        if center_z < 30:
            print("Z < 30: ", patient_id, " center z:", center_z, " y_perc: ", y_perc)
            delete_indices.append(index)
        if (z_perc > 0.75 or z_perc < 0.25) and y_perc > 0.85:
            print("SUSPICIOUS FALSEPOSITIVE: ", patient_id, " center z:", center_z, " y_perc: ", y_perc)
            delete_indices.append(index)
        if center_z < 50 and y_perc < 0.30:
            print("SUSPICIOUS FALSEPOSITIVE OUT OF RANGE: ", patient_id, " center z:", center_z, " y_perc: ",
                  y_perc)
            delete_indices.append(index)

    # Several heuristics can flag the same row, so drop each label once. Drop by label: the
    # old code used these labels as positions into df.index, which is only correct for a
    # default RangeIndex.
    delete_indices = list(dict.fromkeys(delete_indices))
    print("slice to drop:\t", delete_indices)
    df_nodule_predictions.drop(index=delete_indices, inplace=True)
    return df_nodule_predictions
def reduce_predicts_same_slice(pred_nodules_df):
    """Collapse near-duplicate candidates that share a slice.

    Rows are ordered by coord_z (descending). In each run of rows with equal coord_z, the
    first row is the anchor and is always kept. Every other row in the run is kept only if
    ``get_3d_pixel_l2_distance`` to the anchor exceeds 0.2. Inputs with at most one row are
    returned sorted and otherwise unchanged.
    """
    columns = ["anno_index", "coord_x", "coord_y", "coord_z", "nodule_chance"]
    ordered = pred_nodules_df.sort_values(["coord_z"], ascending=False)
    if len(ordered) <= 1:
        return ordered

    kept = []
    anchor = None
    for _, row in ordered.iterrows():
        if anchor is not None and row["coord_z"] == anchor["coord_z"]:
            if get_3d_pixel_l2_distance(anchor, row) > 0.2:
                kept.append(row)
        else:
            # First row of a new slice becomes the anchor and is always kept. The previous
            # version lost the final slice's anchor whenever a distinct same-slice row had
            # been kept.
            anchor = row
            kept.append(row)
    return pandas.DataFrame(kept, columns=columns)
================================================
FILE: inference/detector.py
================================================
from service import settings_jjyang
import cv2
import pandas
import os
import glob
import numpy
from keras import backend as K
from keras.models import model_from_json
from util.cube import save_cube_img, get_cube_from_img
from util.image.processing import normalize_hu_values
from util.dicom_util import get_pixels_hu, extract_dicom_images_patient, load_dicom_slices
from util.image_util import prepare_image_for_net3D, load_patient_images, rescale_patient_images
from util.ml.metrics import get_3d_pixel_l2_distance
from util.progress_watch import Stopwatch
K.set_image_data_format("channels_last")  # Keras image layout: channels last (TF2-style setting)
import tensorflow as tf

# Configure GPU memory use (module-level side effect: runs on import).
# NOTE(review): TensorFlow accepts these settings only before its runtime is initialised; later
# calls raise RuntimeError, which is caught and printed below. Other exception types are not
# caught; confirm that is intended.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Cap each visible GPU at a fixed 1024 * 3 (presumably MB) logical device. The original
        # comment calls this "about 30% of a 10 GB card"; the cap is absolute, not a fraction
        # of the actual card's memory.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
            tf.config.experimental.set_virtual_device_configuration(
                gpu,
                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024 * 3)]  # ~30% of 10 GB
            )
    except RuntimeError as e:
        print(e)
# Edge length (voxels) of the cubes scanned by predict_cubes.
CUBE_SIZE = 32
# Mean pixel value from project settings; not referenced in the visible part of this module.
MEAN_PIXEL_VALUE = settings_jjyang.MEAN_PIXEL_VALUE_NODULE
# Not referenced in the visible part of this module (name suggests negatives per positive in training).
NEGS_PER_POS = 20
# Minimum nodule_chance for a cube to be kept (predict_cubes, multipule_test).
P_TH = 0.7
# Stride (voxels) between neighbouring cubes in predict_cubes.
PREDICT_STEP = 12
# Not referenced in the visible part of this module.
USE_DROPOUT = False
# Half-width (DICOM pixels) of the box drawn around a nodule centre in draw_overlay_dicom.
BOX_size = 20
# Not referenced in the visible part of this module.
BOX_depth = 9
# NODULE_CHANCE = 0.5
# Minimum predicted diameter_mm for a candidate to be kept in multipule_test.
NODULE_DIAMM = 1.0
# Not referenced in the visible part of this module (presumably the image-slice filename suffix).
CUBE_IMGTYPE_SRC = "_i"
def filter_patient_nodules_predictions(df_nodule_predictions: pandas.DataFrame, patient_id, view_size, png_path):
    """Flag implausible candidates by making their ``diameter_mm`` negative (in place).

    A row is flagged when the lung mask sums to <= 255 inside a ``view_size`` x
    ``view_size`` window around its centre on its own slice and the slices directly
    above and below, or when its centre slice index is below 30. Two further positional
    heuristics are only logged. No row is removed; downstream filtering on
    ``diameter_mm`` discards flagged rows.

    Returns:
        The same DataFrame object.
    """
    mask_volume = load_patient_images(png_path + "/png/", "*_m.png")
    rows_to_drop = []  # never filled: the drop at the end is a no-op kept for parity

    for row_label, candidate in df_nodule_predictions.iterrows():
        z_frac = candidate["coord_z"]
        y_frac = candidate["coord_y"]
        cx = int(round(candidate["coord_x"] * mask_volume.shape[2]))
        cy = int(round(y_frac * mask_volume.shape[1]))
        cz = int(round(z_frac * mask_volume.shape[0]))
        diameter = candidate["diameter_mm"]

        # Window origin uses view_size as received for the first row; view_size is then
        # truncated to int and stays truncated for later rows.
        top = int(cy - view_size / 2)
        left = int(cx - view_size / 2)
        view_size = int(view_size)

        touches_lung = False
        for dz in (-1, 0, 1):
            window = mask_volume[cz + dz][top:top + view_size, left:left + view_size]
            if window.sum() > 255:  # more than one mask pixel
                touches_lung = True

        if not touches_lung:
            print("Nodule not in mask: ", (cx, cy, cz))
            if diameter > 0:
                diameter *= -1
            df_nodule_predictions.loc[row_label, "diameter_mm"] = diameter
            continue

        if cz < 30:
            print("Z < 30: ", patient_id, " center z:", cz, " y_perc: ", y_frac)
            if diameter > 0:
                diameter *= -1
            df_nodule_predictions.loc[row_label, "diameter_mm"] = diameter
        if (z_frac > 0.75 or z_frac < 0.25) and y_frac > 0.85:
            print("SUSPICIOUS FALSEPOSITIVE: ", patient_id, " center z:", cz, " y_perc: ", y_frac)
        if cz < 50 and y_frac < 0.30:
            print("SUSPICIOUS FALSEPOSITIVE OUT OF RANGE: ", patient_id, " center z:", cz, " y_perc: ",
                  y_frac)

    df_nodule_predictions.drop(df_nodule_predictions.index[rows_to_drop], inplace=True)
    return df_nodule_predictions
def predict_cubes(png_path, model_path, only_patient_id=None, magnification=1, flip=False):
    """Slide a cube over the patient's PNG volume and predict nodule chance and diameter per lung cube.

    Args:
        png_path: workspace directory. Image slices are read from ``<png_path>/png/*_i.png``
            and lung masks from ``*_m.png``.
        model_path: weights file loaded into the architecture from
            ``../service/workdir/model_loc.json``.
        only_patient_id: patient id used in log output and mask filtering.
        magnification: rescale factor applied to image and mask before scanning.
        flip: mirror each cube along its last axis before prediction.

    Returns:
        DataFrame with columns anno_index, coord_x, coord_y, coord_z (fractional cube
        centres), diameter (diameter_mm / x extent), nodule_chance and diameter_mm, for
        cubes scoring above ``P_TH``. ``filter_patient_nodules_predictions`` has negated
        diameter_mm on implausible rows.
    """
    patient_id = only_patient_id
    # `helpers` is not defined in this module (NameError); Stopwatch is imported from util.progress_watch.
    sw = Stopwatch.start_new()
    with open('../service/workdir/model_loc.json', 'r') as json_file:
        loaded_model_json = json_file.read()
    model = model_from_json(loaded_model_json)
    model.load_weights(model_path)

    patient_img = load_patient_images(png_path + "/png/", "*_i.png", [])
    if magnification != 1:
        patient_img = rescale_patient_images(patient_img, (1, 1, 1), magnification)
    patient_mask = load_patient_images(png_path + "/png/", "*_m.png", [])
    if magnification != 1:
        patient_mask = rescale_patient_images(patient_mask, (1, 1, 1), magnification, is_mask_image=True)

    step = PREDICT_STEP
    CROP_SIZE = CUBE_SIZE
    # Cube positions per axis: offsets 0, step, 2*step, ... while the cube fits strictly inside.
    predict_volume_shape_list = [0, 0, 0]
    for dim in range(3):
        dim_indent = 0
        while dim_indent + CROP_SIZE < patient_img.shape[dim]:
            predict_volume_shape_list[dim] += 1
            dim_indent += step
    predict_volume_shape = tuple(predict_volume_shape_list)
    predict_volume = numpy.zeros(shape=predict_volume_shape, dtype=float)
    print("Predict volume shape: ", predict_volume.shape)

    done_count = 0
    skipped_count = 0
    batch_size = 32
    batch_list = []
    batch_list_coords = []
    patient_predictions_csv = []
    annotation_index = 0

    def score_batch():
        """Predict the queued cubes, record those scoring above P_TH, and empty the queue."""
        nonlocal annotation_index
        batch_data = numpy.vstack(batch_list)
        p = model.predict(batch_data, batch_size=batch_size)
        # p[0]: nodule chance per cube; p[1]: predicted diameter per cube.
        for i in range(len(p[0])):
            p_z, p_y, p_x = batch_list_coords[i]
            nodule_chance = p[0][i][0]
            predict_volume[p_z, p_y, p_x] = nodule_chance
            if nodule_chance > P_TH:
                # Cube centre in voxels, then as a fraction of each axis length.
                p_z_perc = round((p_z * step + CROP_SIZE / 2) / patient_img.shape[0], 4)
                p_y_perc = round((p_y * step + CROP_SIZE / 2) / patient_img.shape[1], 4)
                p_x_perc = round((p_x * step + CROP_SIZE / 2) / patient_img.shape[2], 4)
                diameter_mm = round(p[1][i][0], 4)
                diameter_perc = round(diameter_mm / patient_img.shape[2], 4)
                patient_predictions_csv.append([annotation_index, p_x_perc, p_y_perc, p_z_perc,
                                                diameter_perc, round(nodule_chance, 4), diameter_mm])
                annotation_index += 1
        batch_list.clear()
        batch_list_coords.clear()

    for z in range(predict_volume_shape[0]):
        for y in range(predict_volume_shape[1]):
            for x in range(predict_volume_shape[2]):
                cube_img = patient_img[z * step:z * step + CROP_SIZE, y * step:y * step + CROP_SIZE,
                                       x * step:x * step + CROP_SIZE]
                cube_mask = patient_mask[z * step:z * step + CROP_SIZE, y * step:y * step + CROP_SIZE,
                                         x * step:x * step + CROP_SIZE]
                # Skip cubes with too little lung mask inside.
                if cube_mask.sum() < 2000:
                    skipped_count += 1
                else:
                    if flip:
                        cube_img = cube_img[:, :, ::-1]
                    batch_list.append(prepare_image_for_net3D(cube_img))
                    batch_list_coords.append((z, y, x))
                    if len(batch_list) == batch_size:
                        score_batch()
                done_count += 1
                if done_count % 10000 == 0:
                    print("Scan: ", done_count, " skipped:", skipped_count)
    # The loop only predicts full batches; without this the trailing cubes were never scored.
    if batch_list:
        score_batch()

    df = pandas.DataFrame(patient_predictions_csv,
                          columns=["anno_index", "coord_x", "coord_y", "coord_z", "diameter", "nodule_chance",
                                   "diameter_mm"])
    # Mutates df in place (negates diameter_mm on implausible rows).
    filter_patient_nodules_predictions(df, patient_id, CROP_SIZE * magnification, png_path)
    print(predict_volume.mean())
    print("GPU costs : ", sw.get_elapsed_seconds(), " seconds")
    return df
def reduce_predicts_same_slice(pred_nodules_df):
    """Thin out predictions that lie on the same ``coord_z`` slice.

    Rows are visited in descending (coord_z, diameter_mm) order. The first row
    of each run of equal ``coord_z`` acts as the reference row and is kept once
    the slice changes (or at the end). A later row on the same slice is kept
    only when ``get_3d_pixel_l2_distance`` to the reference row exceeds 0.2.

    :param pred_nodules_df: DataFrame with at least the columns listed in
        ``columns`` below (presumably the output of predict_cubes - confirm).
    :return: DataFrame of the retained rows restricted to ``columns``; when the
        input has 0 or 1 rows the sorted input is returned unchanged.
    """
    rows_filter = []
    pred_nodules_df_local = pred_nodules_df.sort_values(["coord_z", "diameter_mm"], ascending=False)
    if len(pred_nodules_df_local) <= 1:
        # Nothing to merge; note this keeps every original column.
        return pred_nodules_df_local
    i = 0  # NOTE(review): incremented below but never read
    compare_row = pred_nodules_df_local.iloc[0]
    for row_index, row in pred_nodules_df_local[1:].iterrows():
        if compare_row["coord_z"] == row["coord_z"]:
            # Same slice: keep this row only if it is far from the reference row.
            # NOTE(review): the 0.2 threshold's units depend on get_3d_pixel_l2_distance.
            dist = get_3d_pixel_l2_distance(compare_row, row)
            if dist > 0.2:
                rows_filter.append(row)
        else:
            # Slice changed: keep the reference row and start a new run.
            rows_filter.append(compare_row)
            compare_row = row
        i += 1
    if len(rows_filter) == 0:
        rows_filter.append(compare_row)
    last_row = rows_filter[len(rows_filter)-1]
    # Ensure the reference row of the final slice is kept.
    if last_row["coord_z"] != compare_row["coord_z"]:
        rows_filter.append(compare_row)
    columns = ["anno_index", "coord_x", "coord_y", "coord_z", "diameter", "nodule_chance", "diameter_mm"]
    res_df = pandas.DataFrame(rows_filter, columns=columns)
    return res_df
def multipule_test(workspace, only_patient_id, CONTINUE_JOB):
    """Run nodule localisation with two checkpoints and merge their hits.

    Each checkpoint under ``models/`` is run through ``predict_cubes``. A row
    survives only if ``nodule_chance > P_TH`` and ``diameter_mm > NODULE_DIAMM``.
    The survivors of both models are concatenated and passed to
    ``reduce_predicts_same_slice``.

    :param workspace: working directory forwarded to predict_cubes.
    :param only_patient_id: patient id forwarded to predict_cubes.
    :param CONTINUE_JOB: flag forwarded (positionally) to predict_cubes.
    :return: the reduced prediction DataFrame.
    """
    checkpoints = ("model_loc.hd5", "model_loc_val_0.96.hd5")
    per_model = []
    for checkpoint in checkpoints:
        print("gpu begin:")
        predictions = predict_cubes(
            workspace, "models/" + checkpoint, CONTINUE_JOB,
            only_patient_id=only_patient_id, magnification=1, flip=False,
            train_data=True, holdout_no=None, ext_name="luna16_fs",
        )
        confident = predictions[predictions["nodule_chance"] > P_TH]
        per_model.append(confident[confident["diameter_mm"] > NODULE_DIAMM])
    merged = pandas.concat(per_model)
    return reduce_predicts_same_slice(merged)
def draw_overlay_dicom(pixels, coord_x, coord_y, coord_z, i, pixel_spacing, dicom_size, png_size, invert_order, png_path):
    """Draw a detection box on the matching DICOM slice and save it as a PNG.

    Fractional coordinates are scaled by ``png_size`` (z, y, x order) and then
    divided by ``pixel_spacing`` (index 2 for z, 1 for y, 0 for x) to get DICOM
    voxel indices. The slice used is ``dicom_z`` when ``invert_order`` is
    truthy, otherwise ``dicom_size[0] - dicom_z``. The result is clamped to a
    valid index of ``pixels``.

    :param pixels: per-slice HU images, indexable by slice number.
    :param i: string label used in the output file name.
    :param png_path: output directory; writes
        ``<png_path>/overlay_dicom<i>_<slice>.png``.
    """
    z = int(coord_z * png_size[0])
    y = int(coord_y * png_size[1])
    x = int(coord_x * png_size[2])
    dicom_z = int(z / pixel_spacing[2])
    dicom_y = int(y / pixel_spacing[1])
    dicom_x = int(x / pixel_spacing[0])
    x1 = dicom_x - BOX_size
    y1 = dicom_y - BOX_size
    x2 = dicom_x + BOX_size
    y2 = dicom_y + BOX_size
    print("invert_order:", invert_order)
    new_z = dicom_z if invert_order else dicom_size[0] - dicom_z
    # dicom_size[0] - dicom_z equals the slice count when dicom_z == 0
    # (IndexError), and a negative index would silently wrap to the far end
    # of the volume, so clamp to an existing slice.
    new_z = min(max(new_z, 0), len(pixels) - 1)
    org_img = pixels[new_z]
    print("dicom_coord_x_y_z: ", dicom_x, dicom_y, new_z)
    org_img = normalize_hu_values(org_img)
    cv2.rectangle(org_img, (x1, y1), (x2, y2), (255, 0, 0), 1)
    suffix = i + "_" + str(new_z)
    cv2.imwrite(png_path + "/" + "overlay_dicom" + suffix + ".png", org_img * 255)
def _papaya_box_from_fractions(coord_x, coord_y, coord_z, pixel_spacing, dicom_size, png_size, invert_order):
    """Map fractional PNG-volume coords to mirrored DICOM voxel coords and a box.

    Returns ``(box, new_z, new_y, new_x)``. ``box`` is
    ``[z1, y1, x1, z2, y2, x2]``: the centre minus/plus
    (BOX_depth, BOX_size, BOX_size).
    """
    z = int(coord_z * png_size[0])
    y = int(coord_y * png_size[1])
    x = int(coord_x * png_size[2])
    dicom_z = int(z / pixel_spacing[2])
    dicom_y = int(y / pixel_spacing[1])
    dicom_x = int(x / pixel_spacing[0])
    # z and x are mirrored whatever invert_order says (presumably to match
    # the Papaya viewer's orientation - TODO confirm); invert_order is only logged.
    new_z = dicom_size[0] - dicom_z
    new_x = dicom_size[2] - dicom_x
    new_y = dicom_y
    print("invert_order: ", invert_order, "new_coord_z_y_x: ", new_z, new_y, new_x, " dicom_coord_z_y_x: ", dicom_z, dicom_y, dicom_x)
    box = [new_z - BOX_depth, new_y - BOX_size, new_x - BOX_size,
           new_z + BOX_depth, new_y + BOX_size, new_x + BOX_size]
    return box, new_z, new_y, new_x


def get_papaya_coords(coord_x, coord_y, coord_z, nodule_chance, pixel_spacing, dicom_size, png_size, invert_order, ggn_class):
    """Return ``(box, center)`` for a classified nodule in Papaya voxel space.

    ``center`` is ``[z, y, x, nodule_chance*100, ggn_class[0], ggn_class[1]*100]``.
    ggn_class presumably is (label, probability) - confirm with callers.
    """
    box, new_z, new_y, new_x = _papaya_box_from_fractions(
        coord_x, coord_y, coord_z, pixel_spacing, dicom_size, png_size, invert_order)
    center = [new_z, new_y, new_x, nodule_chance*100, ggn_class[0], ggn_class[1]*100]
    return box, center


def get_papaya_coords_only(coord_x, coord_y, coord_z, nodule_chance, pixel_spacing, dicom_size, png_size, invert_order):
    """Return ``(box, center)`` for an unclassified nodule in Papaya voxel space.

    ``center`` is ``[z, y, x, round(nodule_chance, 3)]``.
    """
    box, new_z, new_y, new_x = _papaya_box_from_fractions(
        coord_x, coord_y, coord_z, pixel_spacing, dicom_size, png_size, invert_order)
    center = [new_z, new_y, new_x, round(nodule_chance,3)]
    return box, center
def run(dicom_path, png_save_dir, only_patient_id=None):
    """Detect nodules in a DICOM series and save one overlay PNG per nodule.

    :param dicom_path: DICOM series directory.
    :param png_save_dir: workspace for extracted PNG slices; the overlays are
        also written there.
    :param only_patient_id: forwarded to ``multipule_test`` (which requires it).
        Defaults to None because the original call never supplied one.
    """
    CONTINUE_JOB = True
    sw = Stopwatch.start_new()
    pixel_spacing, dicom_size, png_size, invert_order = extract_dicom_images_patient(dicom_path, png_save_dir)
    # multipule_test takes (workspace, only_patient_id, CONTINUE_JOB); omitting
    # only_patient_id raised TypeError.
    final_nodules_df = multipule_test(png_save_dir, only_patient_id, CONTINUE_JOB)
    slices = load_dicom_slices(dicom_path)
    pixels = get_pixels_hu(slices)
    for i, (_, row) in enumerate(final_nodules_df.iterrows()):
        coord_z = row["coord_z"]
        coord_y = row["coord_y"]
        coord_x = row["coord_x"]
        print("index-x-y-z-p-size", i, coord_x, coord_y, coord_z, row["nodule_chance"], row["diameter_mm"])
        # draw_overlay_dicom requires png_path as its last argument.
        draw_overlay_dicom(pixels, coord_x, coord_y, coord_z, str(i), pixel_spacing, dicom_size, png_size,
                           invert_order, png_save_dir)
    print("ALL Complete in : ", sw.get_elapsed_seconds(), " seconds")
def predict_nodule_type(df, png_size, png_dir):
    """Cut a 32-voxel cube around each predicted nodule and stack them.

    Fractional ``coord_x/coord_y/coord_z`` are scaled by ``png_size`` (z, y, x)
    to PNG-volume indices. Each cube is also written as a preview PNG
    ``<png_dir>/_<x>_<y>_<z>.png``.

    :param df: DataFrame with coord_x, coord_y and coord_z columns.
    :param png_size: PNG volume size as (z, y, x).
    :param png_dir: workspace; slices are loaded from ``png_dir + '/png/'``.
    :return: ``numpy.vstack`` of the prepare_image_for_net3D output of every cube.
    :raises ValueError: if df has no rows (there is nothing to stack).
    """
    images = load_patient_images(png_dir + '/png/', "*" + CUBE_IMGTYPE_SRC + ".png")
    cubes = []  # renamed from `list`, which shadowed the builtin
    for _, row in df.iterrows():
        z = int(row["coord_z"] * png_size[0])
        y = int(row["coord_y"] * png_size[1])
        x = int(row["coord_x"] * png_size[2])
        print("index-x-y-z-p-size-png", x, y, z)
        cube_img = get_cube_from_img(images, x, y, z, 32)
        save_cube_img(png_dir + '/_' + str(x) + '_' + str(y) + '_' + str(z) + '.png', cube_img, 4, 8)
        cubes.append(prepare_image_for_net3D(cube_img))
    if not cubes:
        # numpy.vstack([]) would fail with a less helpful ValueError.
        raise ValueError("predict_nodule_type: df has no rows, nothing to classify")
    return numpy.vstack(cubes)
def generate_ggn_class(ii_index):
    """Translate a classifier output index into its GGN pathology label.

    0 -> 'AAH', 1 -> 'AIS', 2 -> 'MIA', 3 -> 'IA', 4 -> 'OH'; any other
    index gives None.
    """
    labels = {0: 'AAH', 1: 'AIS', 2: 'MIA', 3: 'IA', 4: 'OH'}
    return labels.get(ii_index)
def scan(dicom_path, only_patient_id, workspace, file_type="dicom"):
    """
    Scan for nodules and classify them; supports DICOM and MHD input.

    :param dicom_path: DICOM directory or MHD file path
    :param only_patient_id: patient id, forwarded to predict_cubes
    :param workspace: working directory (PNG slices live under it)
    :param file_type: 'dicom' or 'mhd'
    :return: (boxes, record_list); ([], []) when an exception escapes the
        per-node handler.

    NOTE(review): predict_nodule_type returns a stacked numpy array, not a
    DataFrame. df_pos is nevertheless replaced by that array and then
    iterated with .iterrows() and indexed with "box_label". As written this
    raises AttributeError, which the outer except turns into ([], []). The
    classification step that should produce "box_label" appears to be missing.
    """
    print("workspace", workspace)
    try:
        target_dir = workspace
        p_index = 0
        # Extract images
        if file_type == "dicom":
            # Extract PNG slices from the DICOM series
            pixel_spacing, dicom_size, png_size, invert_order = extract_dicom_images_patient(dicom_path, target_dir)
        else:
            # For MHD input the PNG slices were produced by an earlier step
            pixel_spacing = [1.0, 1.0, 1.0] # placeholder; could be read from the MHD header later
            dicom_size = [0, 0, 0] # placeholder
            png_size = [0, 0, 0] # placeholder
            invert_order = False
            # Try to infer the PNG volume size from the png directory
            png_dir = os.path.join(target_dir, 'png')
            if os.path.exists(png_dir):
                png_files = glob.glob(png_dir + "/*_i.png")
                if png_files:
                    sample_img = cv2.imread(png_files[0], cv2.IMREAD_GRAYSCALE)
                    # (slice count, height, width)
                    png_size = [len(png_files), sample_img.shape[0], sample_img.shape[1]]
        # Predict with the dedicated model.
        # NOTE(review): this feeds the *classify* checkpoint to predict_cubes,
        # while scan_only uses jjnode_loc_resnet_best.h5 - confirm intended.
        model_path = os.path.join(current_dir, '../service/workdir/jjnode_classify_resnet_best.h5')
        df_pos = predict_cubes(target_dir, model_path, True, only_patient_id, False, 1, False)
        # Nodule-type classification (see NOTE in the docstring)
        df_node = predict_nodule_type(df_pos, png_size, target_dir)
        df_pos = df_node
        # Collect boxes and per-node records
        boxes = []
        record_list = []
        records_png_paths = []  # NOTE(review): filled but never returned
        for index, row in df_pos.iterrows():
            record = []
            coord_x = row["coord_x"]
            coord_y = row["coord_y"]
            coord_z = row["coord_z"]
            nodule_chance = float(row["nodule_chance"])
            box_label = row["box_label"]
            # Draw the visualisation overlay
            draw_overlay(target_dir, coord_x, coord_y, coord_z, str(p_index))
            try:
                # get_papaya_coords returns a (box, center) pair; the whole pair is appended.
                box = get_papaya_coords(
                    coord_x, coord_y, coord_z, nodule_chance, pixel_spacing, dicom_size, png_size, invert_order, box_label)
                boxes.append(box)
                # Record layout: [index, x, y, z, label, chance]
                record.append(str(p_index))
                record.append(coord_x)
                record.append(coord_y)
                record.append(coord_z)
                record.append(box_label)
                record.append(nodule_chance)
                record_list.append(record)
                # Render the cube image
                png_f = make_test_cube(record)
                records_png_paths.append(png_f)
            except Exception as e:
                print(f"处理节点时出错: {e}")
            p_index += 1
        return boxes, record_list
    except Exception as e:
        print(f"扫描过程中出错: {e}")
        return [], []
def scan_only(dicom_path, only_patient_id, workspace, file_type="dicom"):
    """
    Locate nodules without classifying them (DICOM directory or MHD file).

    :param dicom_path: DICOM directory or MHD file path
    :param only_patient_id: patient id
    :param workspace: working directory
    :param file_type: 'dicom' or 'mhd'
    :return: (boxes, centers); every entry of ``boxes`` is the (box, center)
        pair from get_papaya_coords_only. Returns ([], []) if the pipeline raises.
    """
    print("workspace", workspace)
    try:
        target_dir = workspace
        model_path = os.path.join(current_dir, '../service/workdir/jjnode_loc_resnet_best.h5')
        if file_type == "dicom":
            pixel_spacing, dicom_size, png_size, invert_order = extract_dicom_images_patient(dicom_path, target_dir)
        else:
            # MHD: PNG slices come from an earlier step; start from placeholders.
            pixel_spacing = [1.0, 1.0, 1.0]
            dicom_size = [0, 0, 0]
            png_size = [0, 0, 0]
            invert_order = False
            png_dir = os.path.join(target_dir, 'png')
            if os.path.exists(png_dir):
                slice_files = glob.glob(png_dir + "/*_i.png")
                if slice_files:
                    first_slice = cv2.imread(slice_files[0], cv2.IMREAD_GRAYSCALE)
                    png_size = [len(slice_files), first_slice.shape[0], first_slice.shape[1]]

        candidates = predict_cubes(target_dir, model_path, True, only_patient_id, False, 1, False)
        candidates = filter_patient_nodules_predictions(candidates, only_patient_id, BOX_size, target_dir)
        candidates = reduce_predicts_same_slice(candidates)

        boxes, centers = [], []
        for p_index, (_, row) in enumerate(candidates.iterrows()):
            coord_x = row["coord_x"]
            coord_y = row["coord_y"]
            coord_z = row["coord_z"]
            nodule_chance = float(row["nodule_chance"])
            draw_overlay(target_dir, coord_x, coord_y, coord_z, str(p_index))
            try:
                boxes.append(get_papaya_coords_only(
                    coord_x, coord_y, coord_z, nodule_chance, pixel_spacing, dicom_size, png_size, invert_order))
                # Centre: fractional coords scaled by dicom_size (all zeros for MHD input).
                diameter = NODULE_DIAMM
                centers.append([
                    int(coord_z * dicom_size[0]),
                    int(coord_y * dicom_size[1]),
                    int(coord_x * dicom_size[2]),
                    diameter,
                    nodule_chance * 100,
                ])
            except Exception as e:
                print(f"处理节点时出错: {e}")
        return boxes, centers
    except Exception as e:
        print(f"扫描过程中出错: {e}")
        return [], []
# horos_5358
if __name__ == "__main__":
    # Manual run against one local study (hard-coded, machine-specific paths).
    here = os.path.abspath(os.path.dirname(__file__))
    workspace = os.path.join(here, './static/workspace/KHAXVYV5')
    scan('E:/renji_hospital_dicom/AIS/ChenZhou/IE2UNXKC/KHAXVYV5', 'KHAXVYV5', workspace)
================================================
FILE: inference/negative_sample_selection.py
================================================
import os
import sys
import numpy as np
import random
import glob
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from scipy import ndimage
# 导入CTData类
from data.dataclass.CTData import CTData
# Edge length (voxels) of the cubes extracted as negative samples.
CUBE_SIZE = 32
def variance_of_laplacian(image):
    """Return the variance of the Laplacian of *image*, a sharpness/blur measure."""
    return np.var(ndimage.laplace(image))
def texture_score(cube):
    """Score how much texture a 3-D cube contains.

    Weighted blend of three measures:
      * mean Laplacian variance over the axis-0 slices (weight 0.3),
      * mean 3-D Sobel gradient magnitude (weight 0.3),
      * voxel standard deviation (weight 0.4).
    """
    slice_sharpness = np.mean([variance_of_laplacian(plane) for plane in cube])
    magnitude = np.sqrt(
        ndimage.sobel(cube, axis=2) ** 2
        + ndimage.sobel(cube, axis=1) ** 2
        + ndimage.sobel(cube, axis=0) ** 2
    )
    gradient_strength = np.mean(magnitude)
    spread = np.std(cube)
    return (slice_sharpness * 0.3) + (gradient_strength * 0.3) + (spread * 0.4)
def is_valid_negative_sample(cube, lung_mask=None, nodule_coords=None, min_distance=20, min_texture_score=0.1,
                             cube_origin=None):
    """
    Decide whether a cube can be used as a negative (non-nodule) sample.
    Args:
        cube: CUBE_SIZE^3 image cube.
        lung_mask: lung mask cropped to the same region as ``cube``. When given,
            np.mean(lung_mask) must be at least 0.6.
        nodule_coords: known nodule positions, compared with the cube centre.
            They must use the same axis order as ``cube_origin`` (callers in this
            module pass (z, y, x)).
        min_distance: minimum Euclidean distance (voxels) between the cube
            centre and every known nodule.
        min_texture_score: minimum texture_score of the cube.
        cube_origin: (z, y, x) index of the cube's first voxel in the full
            volume. Without it the centre is the cube-local point
            (CUBE_SIZE//2,)*3. That is the legacy behaviour, which cannot relate
            the cube to volume-space nodule coordinates.
    Returns:
        True if the cube passes every enabled check.
    """
    if texture_score(cube) < min_texture_score:
        return False
    # Require at least 60% lung voxels (for a 0/1 mask).
    if lung_mask is not None and np.mean(lung_mask) < 0.6:
        return False
    if nodule_coords is not None and len(nodule_coords) > 0:
        half = CUBE_SIZE // 2
        cube_center = np.array([half, half, half])
        if cube_origin is not None:
            cube_center = cube_center + np.asarray(cube_origin)
        for nodule_coord in nodule_coords:
            if np.linalg.norm(cube_center - np.array(nodule_coord)) < min_distance:
                return False
    return True


def select_negative_samples_from_ct(ct_data, nodule_coords=None, num_samples=100, strategy='random'):
    """
    Pick negative-sample cubes from one CT scan.
    Args:
        ct_data: CTData object. Its lung_seg_img / lung_seg_mask are used, and
            filter_lung_img_mask() is run first if lung_seg_img is None.
        nodule_coords: known nodule positions as (x, y, z) voxel tuples, the
            order generate_negative_samples builds them in. They are assumed to
            be in the same voxel grid as lung_seg_img (TODO confirm: nothing is
            resampled here).
        num_samples: number of negatives wanted.
        strategy: 'random' keeps the highest-texture valid random cubes;
            'kmeans' clusters candidate features and keeps the best cube per cluster.
    Returns:
        List of CUBE_SIZE^3 arrays. Empty when there are no lung voxels or the
        lung bounding box cannot fit a full cube.
    Raises:
        ValueError: for an unknown strategy.
    """
    if ct_data.lung_seg_img is None:
        print("肺部区域数据未分割,正在分割...")
        ct_data.filter_lung_img_mask()
    lung_img = ct_data.lung_seg_img
    lung_mask = ct_data.lung_seg_mask
    z_indices, y_indices, x_indices = np.where(lung_mask > 0)
    if len(z_indices) == 0:
        print("未找到肺部区域")
        return []
    # Cube origins range over the lung bounding box, limited so a full cube
    # still fits inside the volume.
    z_min, y_min, x_min = z_indices.min(), y_indices.min(), x_indices.min()
    z_max = min(lung_img.shape[0] - CUBE_SIZE, z_indices.max())
    y_max = min(lung_img.shape[1] - CUBE_SIZE, y_indices.max())
    x_max = min(lung_img.shape[2] - CUBE_SIZE, x_indices.max())
    if z_max < z_min or y_max < y_min or x_max < x_min:
        # random.randint would raise ValueError on an empty range.
        print("肺部区域过小或过于靠近边界,无法放置完整立方体")
        return []
    # Cube origins are (z, y, x); convert nodules from (x, y, z) to match.
    nodule_zyx = None
    if nodule_coords is not None:
        nodule_zyx = [(c[2], c[1], c[0]) for c in nodule_coords]

    def _random_origin():
        # Draw order z, y, x is kept so seeded runs remain reproducible.
        return (random.randint(z_min, z_max), random.randint(y_min, y_max), random.randint(x_min, x_max))

    def _crop(origin):
        # Image cube and matching lung-mask cube at *origin*.
        z, y, x = origin
        return (lung_img[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE],
                lung_mask[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE])

    if strategy == 'random':
        candidate_samples = []
        attempts = 0
        max_attempts = num_samples * 10
        while len(candidate_samples) < num_samples and attempts < max_attempts:
            origin = _random_origin()
            cube, cube_mask = _crop(origin)
            if is_valid_negative_sample(cube, cube_mask, nodule_zyx, cube_origin=origin):
                candidate_samples.append({'cube': cube, 'position': origin, 'score': texture_score(cube)})
            attempts += 1
        candidate_samples.sort(key=lambda c: c['score'], reverse=True)
        return [c['cube'] for c in candidate_samples[:num_samples]]
    elif strategy == 'kmeans':
        all_candidates = []
        for _ in range(num_samples * 5):
            origin = _random_origin()
            cube, cube_mask = _crop(origin)
            if is_valid_negative_sample(cube, cube_mask, nodule_zyx, cube_origin=origin):
                texture = texture_score(cube)
                all_candidates.append({
                    'cube': cube,
                    'position': origin,
                    'features': [np.mean(cube), np.std(cube), texture],
                    'score': texture,
                })
        if len(all_candidates) < num_samples:
            print(f"警告: 只找到 {len(all_candidates)} 个候选样本,少于请求的 {num_samples} 个")
            return [c['cube'] for c in all_candidates]
        features = np.array([c['features'] for c in all_candidates], dtype=float)
        std = features.std(axis=0)
        # A constant feature column would make the z-score 0/0 = NaN, which KMeans rejects.
        std[std == 0] = 1.0
        features = (features - features.mean(axis=0)) / std
        n_clusters = min(num_samples, len(all_candidates))
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        clusters = kmeans.fit_predict(features)
        selected = []
        for label in range(n_clusters):
            members = [c for c, cl in zip(all_candidates, clusters) if cl == label]
            if members:
                selected.append(max(members, key=lambda c: c['score']))
        return [c['cube'] for c in selected]
    else:
        raise ValueError(f"不支持的选择策略: {strategy}")
def generate_negative_samples(ct_paths, output_dir, nodules_csv=None, samples_per_ct=50, strategy='random'):
    """
    Generate negative samples from several CT scans.
    Args:
        ct_paths: list of .mhd file paths or DICOM directory paths.
        output_dir: output directory; each sample is saved as
            "<patient_id>_neg_<i>.npy".
        nodules_csv: optional CSV with columns patient_id, voxel_x, voxel_y,
            voxel_z, listing known nodules to stay away from.
        samples_per_ct: number of negatives requested per scan.
        strategy: 'random' or 'kmeans', forwarded to select_negative_samples_from_ct.
    Errors for an individual scan are printed and that scan is skipped.
    """
    os.makedirs(output_dir, exist_ok=True)
    nodule_coords_by_patient = {}
    if nodules_csv and os.path.exists(nodules_csv):
        import pandas as pd
        nodules_df = pd.read_csv(nodules_csv)
        for _, row in nodules_df.iterrows():
            # IDs derived from paths below are strings; pandas may parse purely
            # numeric IDs as ints, so normalise the key.
            nodule_coords_by_patient.setdefault(str(row['patient_id']), []).append(
                (row['voxel_x'], row['voxel_y'], row['voxel_z'])
            )
    for ct_path in ct_paths:
        try:
            if os.path.isfile(ct_path):
                patient_id = os.path.splitext(os.path.basename(ct_path))[0]
            else:
                # normpath drops a trailing separator, which would make basename ''.
                patient_id = os.path.basename(os.path.normpath(ct_path))
            print(f"处理患者 {patient_id}...")
            if os.path.isfile(ct_path) and ct_path.endswith('.mhd'):
                ct_data = CTData.from_mhd(ct_path)
            elif os.path.isdir(ct_path):
                ct_data = CTData.from_dicom(ct_path)
            else:
                print(f"跳过不支持的文件类型: {ct_path}")
                continue
            nodule_coords = nodule_coords_by_patient.get(patient_id, None)
            negative_samples = select_negative_samples_from_ct(
                ct_data,
                nodule_coords=nodule_coords,
                num_samples=samples_per_ct,
                strategy=strategy,
            )
            for i, sample in enumerate(negative_samples):
                output_path = os.path.join(output_dir, f"{patient_id}_neg_{i:03d}.npy")
                np.save(output_path, sample)
            print(f"已为患者 {patient_id} 生成 {len(negative_samples)} 个负样本")
        except Exception as e:
            print(f"处理 {ct_path} 时出错: {str(e)}")
    print("负样本生成完成!")


def visualize_samples(samples, output_path=None, cols=5):
    """Show the middle axis-0 slice of each sample cube in a grid (quality check).

    Saves the figure to *output_path* if given, otherwise shows it.
    Does nothing for an empty sample list.
    """
    if len(samples) == 0:
        return
    rows = (len(samples) + cols - 1) // cols
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; there,
    # plt.subplots would otherwise return a bare Axes without .flatten().
    fig, axes = plt.subplots(rows, cols, figsize=(cols*3, rows*3), squeeze=False)
    axes = axes.flatten()
    for i, sample in enumerate(samples):
        middle_slice = sample[sample.shape[0]//2]
        axes[i].imshow(middle_slice, cmap='gray')
        axes[i].set_title(f"Sample {i}")
        axes[i].axis('off')
    # Hide the unused grid cells.
    for ax in axes[len(samples):]:
        ax.axis('off')
    plt.tight_layout()
    if output_path:
        plt.savefig(output_path)
        plt.close()
    else:
        plt.show()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='负样本选择工具')
    parser.add_argument('--input', type=str, required=True, help='输入CT文件或目录,或包含多个CT路径的文本文件')
    parser.add_argument('--output', type=str, required=True, help='输出目录')
    parser.add_argument('--nodules', type=str, default=None, help='包含已知结节信息的CSV文件')
    parser.add_argument('--num-samples', type=int, default=50, help='每个CT数据生成的负样本数量')
    parser.add_argument('--strategy', type=str, default='random', choices=['random', 'kmeans'], help='样本选择策略')
    args = parser.parse_args()
    if os.path.isfile(args.input) and args.input.endswith('.txt'):
        # A .txt input lists one CT path per line.
        with open(args.input, 'r', encoding='utf-8') as f:
            ct_paths = [line.strip() for line in f if line.strip()]
    else:
        ct_paths = [args.input]
    # --strategy was previously parsed but never used.
    generate_negative_samples(
        ct_paths,
        args.output,
        nodules_csv=args.nodules,
        samples_per_ct=args.num_samples,
        strategy=args.strategy,
    )
================================================
FILE: inference/pytorch_nodule_detector.py
================================================
import os
import numpy as np
import pandas as pd
import torch
from torch.nn import functional as F
from datetime import datetime
import logging
import time
from scipy import ndimage
from data.dataclass.CTData import CTData
from data.dataclass.NoduleCube import normal_cube_to_tensor
from data.preprocessing.luna16_invalid_nodule_filter import nodule_valid
from models.pytorch_c3d_tiny import C3dTiny
# Inference parameters
CUBE_SIZE = 32 # edge length of each scanned cube (32x32x32 voxels)
SCAN_STEP = 10 # default stride (voxels) between cube origins along y/x in scan_ct_data
PROB_THRESHOLD = 0.8 # a cube counts as a nodule only if its class-1 probability exceeds this
# Logging setup
def setup_logger(log_dir="./inference_logs"):
    """Configure and return the 'nodule_detection' logger.

    Each call opens a fresh timestamped log file in *log_dir* and also logs to
    the console, both at INFO level. Handlers installed by an earlier call are
    closed and removed first. ``logging.getLogger`` returns the same object
    every time, so re-adding handlers would duplicate every message. Handlers
    attached by other code are left alone.
    """
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
    logger = logging.getLogger('nodule_detection')
    logger.setLevel(logging.INFO)
    for stale in list(logger.handlers):
        if getattr(stale, "_nodule_detection_owned", False):
            logger.removeHandler(stale)
            stale.close()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in (logging.FileHandler(log_file), logging.StreamHandler()):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        handler._nodule_detection_owned = True  # marks handlers this function manages
        logger.addHandler(handler)
    return logger
def load_ct_data(file_path):
    """
    Load a CT scan (``.mhd`` file or DICOM directory) and prepare it for inference.

    The volume is resampled to 1 mm spacing and lung segmentation is run on it.
    Args:
        file_path: path to an .mhd file or a DICOM directory
    Returns:
        the prepared CTData object
    Raises:
        ValueError: if the path is a non-.mhd file or does not exist.
    """
    if os.path.isfile(file_path):
        if not file_path.endswith('.mhd'):
            raise ValueError(f"不支持的文件类型: {file_path}")
        ct_data = CTData.from_mhd(file_path)
    elif os.path.isdir(file_path):
        ct_data = CTData.from_dicom(file_path)
    else:
        raise ValueError(f"指定路径不存在: {file_path}")
    ct_data = ct_data.resample_pixel(new_spacing=[1, 1, 1])
    ct_data.filter_lung_img_mask()
    return ct_data
def load_model(evaL_model_path, device='cuda'):
    """
    Build a C3dTiny network on *device*, load its weights and switch to eval mode.
    Args:
        evaL_model_path: path of the saved state_dict
        device: target device ('cuda' or 'cpu', or a torch.device)
    Returns:
        (model, device); the device is passed back unchanged.
    """
    network = C3dTiny().to(device)
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    weights = torch.load(evaL_model_path, map_location=device)
    network.load_state_dict(weights)
    network.eval()
    return network, device
def get_lung_bounds(lung_mask):
    """Compute per-slice scan windows around the lung mask.

    If the mask has more than two connected components, only the two largest
    are kept (assumed to be the left and right lungs). Every axial slice
    (axis 0) whose kept mask has more than 100 voxels yields a window: the
    slice's y/x extent padded by 5 voxels and clipped to the volume.

    Returns:
        None if the mask is empty or no slice qualifies. Otherwise a dict with
        'z_min' (first qualifying slice), 'z_max' (last qualifying slice + 1)
        and 'regions' (list of dicts with keys 'z', 'y_min', 'y_max', 'x_min', 'x_max').
    """
    if lung_mask.sum() == 0:
        return None
    labeled_mask, num_features = ndimage.label(lung_mask > 0)
    if num_features > 2:
        # Voxel count per label in one pass (index 0 is background), instead of
        # one full-volume comparison per label.
        region_sizes = np.bincount(labeled_mask.ravel(), minlength=num_features + 1)[1:]
        valid_labels = np.argsort(region_sizes)[-2:] + 1
        refined_mask = np.isin(labeled_mask, valid_labels).astype(labeled_mask.dtype)
    else:
        refined_mask = lung_mask > 0

    margin = 5  # padding (voxels) around each slice's lung extent
    regions = []
    for z, slice_mask in enumerate(refined_mask):
        if slice_mask.sum() <= 100:  # too few lung voxels on this slice
            continue
        y_indices, x_indices = np.where(slice_mask)
        regions.append({
            'z': z,
            'y_min': max(0, y_indices.min() - margin),
            'y_max': min(refined_mask.shape[1], y_indices.max() + margin),
            'x_min': max(0, x_indices.min() - margin),
            'x_max': min(refined_mask.shape[2], x_indices.max() + margin),
        })
    if not regions:
        return None
    return {
        'z_min': regions[0]['z'],
        'z_max': regions[-1]['z'] + 1,
        'regions': regions,
    }
def scan_ct_data(ct_data, model, device, logger, step=SCAN_STEP):
    """
    Slide a CUBE_SIZE^3 window over the lung region and collect nodule hits.

    For every slice window returned by get_lung_bounds, cubes start at that
    window's z. Every qualifying slice is used, i.e. the z stride is 1; the
    y/x origins are spaced `step` apart. Cubes whose mean lung-mask value is
    below 0.1 are skipped. The rest are run through process_batch in batches
    of 32.

    Args:
        ct_data: CTData object; lung_seg_img and lung_seg_mask are read
        model: PyTorch model
        device: compute device
        logger: logger
        step: y/x stride between cube origins
    Returns:
        DataFrame with one row per cube whose probability exceeded
        PROB_THRESHOLD (voxel_coord_*, world_coord_*, prob). If there are no
        hits it is empty but has those columns.
    """
    logger.info("开始扫描CT数据...")
    # Lung-segmented image and mask
    lung_img = ct_data.lung_seg_img
    lung_mask = ct_data.lung_seg_mask
    # Per-slice lung bounds
    bounds = get_lung_bounds(lung_mask)
    if bounds is None:
        logger.warning("未能找到有效的肺部区域")
        return pd.DataFrame(columns=['voxel_coord_x', 'voxel_coord_y', 'voxel_coord_z',
                                     'world_coord_x', 'world_coord_y', 'world_coord_z', 'prob'])
    logger.info(f"已确定肺部区域: Z轴范围 {bounds['z_min']} 到 {bounds['z_max']}, 共 {len(bounds['regions'])} 个切片")
    # Hits, appended to in place by process_batch
    results = []
    # Rough count of cube positions, used only for progress logging.
    # NOTE(review): it does not subtract CUBE_SIZE like the scan loops do, so it overestimates.
    total_voxels = 0
    for region in bounds['regions']:
        y_range = region['y_max'] - region['y_min']
        x_range = region['x_max'] - region['x_min']
        total_voxels += (y_range // step + 1) * (x_range // step + 1)
    logger.info(f"预计扫描体素数: {total_voxels}")
    # Start timing
    start_time = time.time()
    batch_size = 32 # larger batches improve GPU utilisation
    batch_inputs = []
    batch_positions = []
    # Progress counters
    processed_voxels = 0
    skipped_voxels = 0
    # Minimum lung fraction per cube (mean of mask values)
    lung_tissue_threshold = 0.1
    # Scan the lung region slice by slice
    for z_idx, region in enumerate(bounds['regions']):
        z = region['z']
        # Skip slices where a full cube would run past the end of the volume
        if z + CUBE_SIZE > lung_img.shape[0]:
            continue
        # Scan this slice's y/x window
        for y in range(region['y_min'], region['y_max'] - CUBE_SIZE + 1, step):
            for x in range(region['x_min'], region['x_max'] - CUBE_SIZE + 1, step):
                # Lung-mask cube at this position
                mask_cube = lung_mask[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE]
                # Mean mask value (the lung fraction if the mask is 0/1)
                lung_ratio = np.mean(mask_cube)
                # Too little lung tissue: skip
                if lung_ratio < lung_tissue_threshold:
                    skipped_voxels += 1
                    continue
                # Image cube at this position
                cube = lung_img[z:z+CUBE_SIZE, y:y+CUBE_SIZE, x:x+CUBE_SIZE]
                # Preprocess into a tensor
                cube_tensor = normal_cube_to_tensor(cube)
                # Add a leading batch dimension
                cube_tensor = cube_tensor.unsqueeze(0)
                # Queue for batched inference
                batch_inputs.append(cube_tensor)
                batch_positions.append((z, y, x))
                # Run the model once the batch is full
                if len(batch_inputs) == batch_size:
                    # Process the current batch
                    process_batch(batch_inputs, batch_positions, model, device, ct_data, results)
                    batch_inputs = []
                    batch_positions = []
                processed_voxels += 1
                # Periodic progress report
                if (processed_voxels + skipped_voxels) % 1000 == 0:
                    elapsed_time = time.time() - start_time
                    progress = processed_voxels / total_voxels * 100 if total_voxels > 0 else 0
                    logger.info(f"处理进度: {processed_voxels}/{total_voxels} ({progress:.2f}%), "
                                f"已跳过: {skipped_voxels}, 耗时: {elapsed_time:.2f}秒")
    # Flush the final partial batch
    if batch_inputs:
        process_batch(batch_inputs, batch_positions, model, device, ct_data, results)
    # Build the result DataFrame
    if results:
        results_df = pd.DataFrame(results)
        logger.info(f"扫描完成! 发现 {len(results_df)} 个可能的结节")
    else:
        results_df = pd.DataFrame(columns=['voxel_coord_x', 'voxel_coord_y', 'voxel_coord_z',
                                           'world_coord_x', 'world_coord_y', 'world_coord_z', 'prob'])
        logger.info("扫描完成! 未发现任何结节")
    return results_df
def process_batch(batch_inputs, batch_positions, model, device, ct_data, results):
    """Run the model on one batch and append above-threshold hits to *results*.

    Args:
        batch_inputs: list of tensors with a leading batch dim (concatenated on dim 0).
        batch_positions: (z, y, x) cube origins, aligned with batch_inputs.
        model: network whose softmax column 1 is taken as the nodule probability.
        device: device to run on.
        ct_data: object providing voxel_to_world([x, y, z]).
        results: list mutated in place; one dict is appended per cube whose
            probability exceeds PROB_THRESHOLD, keyed by the cube centre.
    """
    batch_tensor = torch.cat(batch_inputs, dim=0).to(device)
    with torch.no_grad():
        logits = model(batch_tensor)
        nodule_probs = F.softmax(logits, dim=1)[:, 1]
    # One device->host copy for the whole batch instead of a sync per .item().
    for (z_pos, y_pos, x_pos), prob_value in zip(batch_positions, nodule_probs.cpu().tolist()):
        if prob_value > PROB_THRESHOLD:
            center_z = z_pos + CUBE_SIZE // 2
            center_y = y_pos + CUBE_SIZE // 2
            center_x = x_pos + CUBE_SIZE // 2
            # voxel_to_world takes (x, y, z) order
            world_coord = ct_data.voxel_to_world([center_x, center_y, center_z])
            results.append({
                'voxel_coord_x': center_x,
                'voxel_coord_y': center_y,
                'voxel_coord_z': center_z,
                'world_coord_x': world_coord[0],
                'world_coord_y': world_coord[1],
                'world_coord_z': world_coord[2],
                'prob': prob_value
            })
def reduce_overlapping_nodules(results_df, distance_threshold=15):
    """
    Non-maximum suppression of nodule predictions by voxel distance.

    Predictions are ranked by 'prob' (descending). Each surviving prediction
    suppresses every lower-ranked survivor whose Euclidean distance in
    (voxel_coord_x, voxel_coord_y, voxel_coord_z) is below distance_threshold.
    Args:
        results_df: DataFrame of predictions
        distance_threshold: merge distance in voxels
    Returns:
        The surviving rows sorted by prob with a fresh 0..n-1 index. Inputs with
        0 or 1 rows are returned unchanged.
    """
    if len(results_df) <= 1:
        return results_df
    sorted_df = results_df.sort_values('prob', ascending=False).reset_index(drop=True)
    # Pull the coordinates out once; per-pair .iloc lookups were the bottleneck.
    coords = sorted_df[['voxel_coord_x', 'voxel_coord_y', 'voxel_coord_z']].to_numpy(dtype=float)
    keep_mask = np.ones(len(sorted_df), dtype=bool)
    for i in range(len(sorted_df)):
        if not keep_mask[i]:
            continue
        distances = np.sqrt(((coords[i + 1:] - coords[i]) ** 2).sum(axis=1))
        # `~(d < t)` rather than `d >= t` keeps NaN distances, as the original loop did.
        keep_mask[i + 1:] &= ~(distances < distance_threshold)
    return sorted_df[keep_mask].reset_index(drop=True)
def filter_false_positives(nodules_df, ct_data, max_nodules=10):
    """
    Drop implausible nodule candidates.

    Only the max_nodules highest-probability candidates are considered; each
    is then kept only if nodule_valid(ct_data, x, y, z) is truthy for its
    integer voxel coordinates.
    Args:
        nodules_df: DataFrame with voxel_coord_x/y/z and prob columns
        ct_data: CTData object passed through to nodule_valid
        max_nodules: maximum number of candidates considered per patient
    Returns:
        Filtered DataFrame. It keeps the input's columns, dtypes and index even
        when no row survives. An empty input is returned as-is.
    """
    if nodules_df.empty:
        return nodules_df
    if len(nodules_df) > max_nodules:
        nodules_df = nodules_df.sort_values('prob', ascending=False).head(max_nodules)
    keep = np.array([
        bool(nodule_valid(ct_data, int(row['voxel_coord_x']), int(row['voxel_coord_y']), int(row['voxel_coord_z'])))
        for _, row in nodules_df.iterrows()
    ], dtype=bool)
    # Boolean indexing, unlike rebuilding from iterrows() Series, preserves
    # the column set when nothing survives and does not upcast int columns.
    return nodules_df.loc[keep]
def format_results(results_df, ct_data, patient_id):
    """
    Build the final per-nodule output table.
    Args:
        results_df: filtered nodule DataFrame (voxel_coord_*, world_coord_*, prob)
        ct_data: CTData object (currently unused; kept for interface compatibility)
        patient_id: patient id written to every row
    Returns:
        DataFrame with columns patient_id, nodule_id, voxel_x/y/z, world_x/y/z,
        diameter_mm, prob. nodule_id runs 1..n in row order. It is empty but has
        those columns when results_df is empty.
    """
    columns = ['patient_id', 'nodule_id', 'voxel_x', 'voxel_y', 'voxel_z',
               'world_x', 'world_y', 'world_z', 'diameter_mm', 'prob']
    if results_df.empty:
        return pd.DataFrame(columns=columns)
    # The detector does not estimate size; CUBE_SIZE / 2 is a fixed placeholder.
    diameter_mm = CUBE_SIZE / 2
    # Sequential ids: the incoming index has gaps after filtering.
    final_results = [
        {
            'patient_id': patient_id,
            'nodule_id': nodule_id,
            'voxel_x': int(row['voxel_coord_x']),
            'voxel_y': int(row['voxel_coord_y']),
            'voxel_z': int(row['voxel_coord_z']),
            'world_x': row['world_coord_x'],
            'world_y': row['world_coord_y'],
            'world_z': row['world_coord_z'],
            'diameter_mm': diameter_mm,
            'prob': row['prob'],
        }
        for nodule_id, (_, row) in enumerate(results_df.iterrows(), start=1)
    ]
    return pd.DataFrame(final_results, columns=columns)
def detect_nodules(file_path, model_path, detect_patient_id=None, device='cuda'):
    """
    Run the full nodule-detection pipeline on one CT scan.

    Steps: load and preprocess, sliding-window scan, merge overlapping hits,
    drop false positives, format the output.
    Args:
        file_path: .mhd file or DICOM directory
        model_path: model weights file
        detect_patient_id: patient id; defaults to the file stem (for a file)
            or the directory name (for a directory)
        device: compute device ('cuda' or 'cpu', or a torch.device)
    Returns:
        DataFrame produced by format_results
    Raises:
        Any exception from the pipeline, after it has been logged.
    """
    logger = setup_logger()
    if detect_patient_id is None:
        if os.path.isfile(file_path):
            detect_patient_id = os.path.splitext(os.path.basename(file_path))[0]
        else:
            # normpath drops a trailing separator, which would make basename ''.
            detect_patient_id = os.path.basename(os.path.normpath(file_path))
    logger.info(f"开始处理患者 {detect_patient_id} 的CT数据")
    try:
        logger.info(f"加载CT数据: {file_path}")
        ct_data = load_ct_data(file_path)
        logger.info(f"加载模型: {model_path}")
        model, device = load_model(model_path, device)
        results_df = scan_ct_data(ct_data, model, device, logger)
        logger.info("合并重叠结节...")
        reduced_df = reduce_overlapping_nodules(results_df)
        logger.info(f"合并后的结节数量: {len(reduced_df)}")
        logger.info("过滤假阳性结节...")
        filtered_df = filter_false_positives(reduced_df, ct_data)
        logger.info(f"过滤后的结节数量: {len(filtered_df)}")
        # Use the id resolved above. The module-level `patient_id` previously
        # referenced here only exists when this file runs as a script.
        final_df = format_results(filtered_df, ct_data, detect_patient_id)
        logger.info(f"检测完成,找到 {len(final_df)} 个结节")
        return final_df
    except Exception as e:
        logger.error(f"检测过程中出错: {str(e)}", exc_info=True)
        raise
if __name__ == "__main__":
    # Manual smoke run on one LUNA16 series (machine-specific paths).
    patient_id = "1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139"
    test_mhd = "H:/luna16/subset8/1.3.6.1.4.1.14519.5.2.1.6279.6001.149041668385192796520281592139.mhd"
    model_path = "../training/pytorch_checkpoints/best_model.pth"
    threshold = 0.7  # NOTE(review): unused; detection uses PROB_THRESHOLD
    detect_result_csv = "./c3d_classify_result-%s.csv" % patient_id
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    result_df = detect_nodules(test_mhd, model_path, None, device)
    result_df.to_csv(detect_result_csv, index=False, encoding="utf-8")
================================================
FILE: models/__init__.py
================================================
================================================
FILE: models/pytorch_c3d_tiny.py
================================================
import torch.nn as nn
import torchvision.transforms as transforms
# my_tranform =transforms.Compose([
# # transforms.Resize((32,32,32)),
# transforms.ToTensor(),
# transforms.Normalize((0.5,0.5,0.5), (0.5, 0.5,0.5))
# ])
class C3dTiny(nn.Module):
    """Compact 3D CNN that maps a single-channel cube to 2 class logits.

    The fully connected head assumes a 32x32x32 input: the four pooling
    stages reduce it to a 512-channel 4x2x2 feature map (8192 features).
    Attribute names and the index of each layer inside its ``nn.Sequential``
    form the ``state_dict`` keys, so they must not change or saved
    checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: pool only in-plane so depth is kept (32 -> 32x16x16).
        # BatchNorm layers are an addition relative to the original C3D design.
        self.conv_block1 = nn.Sequential(
            *self._conv_bn_relu(1, 64),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
        )
        # Stage 2: -> 16x8x8
        self.conv_block2 = nn.Sequential(
            *self._conv_bn_relu(64, 128),
            nn.MaxPool3d(kernel_size=2),
        )
        self.drop_out1 = nn.Dropout(0.2)
        # Stage 3: two convs, -> 8x4x4
        self.conv_block3 = nn.Sequential(
            *self._conv_bn_relu(128, 256),
            *self._conv_bn_relu(256, 256),
            nn.MaxPool3d(kernel_size=2),
        )
        self.drop_out2 = nn.Dropout(0.2)
        # Stage 4: two convs, -> 4x2x2
        self.conv_block4 = nn.Sequential(
            *self._conv_bn_relu(256, 512),
            *self._conv_bn_relu(512, 512),
            nn.MaxPool3d(kernel_size=2),
        )
        self.drop_out3 = nn.Dropout(0.2)
        self.flatten = nn.Flatten()
        # 512 channels * 4 * 2 * 2 spatial positions = 8192 inputs.
        self.fc1 = nn.Sequential(
            nn.Linear(512 * 4 * 2 * 2, 512),
            nn.ReLU(),
        )
        self.fc2 = nn.Linear(512, 2)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return the layers [Conv3d(k=3, pad=1), BatchNorm3d, ReLU]."""
        return [
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
        ]

    def forward(self, x):
        """Run the conv stages, dropouts and classifier head in order."""
        pipeline = (
            self.conv_block1,
            self.conv_block2, self.drop_out1,
            self.conv_block3, self.drop_out2,
            self.conv_block4, self.drop_out3,
            self.flatten,
            self.fc1,
            self.fc2,
        )
        for stage in pipeline:
            x = stage(x)
        return x
================================================
FILE: training/__init__.py
================================================
================================================
FILE: training/pytorch_logs/training_20250331_223230.log
================================================
2025-03-31 22:32:30,594 - c3d_training - INFO - 训练集: 23640个样本
2025-03-31 22:32:30,594 - c3d_training - INFO - 验证集: 5910个样本
2025-03-31 22:32:30,749 - c3d_training - INFO - 模型结构:
C3dTiny(
(conv_block1): Sequential(
(0): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0, dilation=1, ceil_mode=False)
)
(conv_block2): Sequential(
(0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(drop_out1): Dropout(p=0.2, inplace=False)
(conv_block3): Sequential(
(0): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(4): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(drop_out2): Dropout(p=0.2, inplace=False)
(conv_block4): Sequential(
(0): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(4): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(drop_out3): Dropout(p=0.2, inplace=False)
(flatten): Flatten(start_dim=1, end_dim=-1)
(fc1): Sequential(
(0): Linear(in_features=8192, out_features=512, bias=True)
(1): ReLU()
)
(fc2): Linear(in_features=512, out_features=2, bias=True)
)
2025-03-31 22:33:51,563 - c3d_training - INFO - Epoch [1/15], Train Loss: 0.3146, Train Acc: 89.05%, Val Loss: 0.1623, Val Acc: 93.28%, Time: 80.81s
2025-03-31 22:33:51,652 - c3d_training - INFO - Epoch [1]: 保存最佳模型, 验证准确率: 93.28%
2025-03-31 22:35:12,255 - c3d_training - INFO - Epoch [2/15], Train Loss: 0.0834, Train Acc: 97.17%, Val Loss: 0.0728, Val Acc: 97.33%, Time: 80.39s
2025-03-31 22:35:12,341 - c3d_training - INFO - Epoch [2]: 保存最佳模型, 验证准确率: 97.33%
2025-03-31 22:36:33,755 - c3d_training - INFO - Epoch [3/15], Train Loss: 0.0504, Train Acc: 98.30%, Val Loss: 0.0383, Val Acc: 98.61%, Time: 81.20s
2025-03-31 22:36:33,843 - c3d_training - INFO - Epoch [3]: 保存最佳模型, 验证准确率: 98.61%
2025-03-31 22:37:55,955 - c3d_training - INFO - Epoch [4/15], Train Loss: 0.0368, Train Acc: 98.80%, Val Loss: 0.0522, Val Acc: 98.22%, Time: 81.89s
2025-03-31 22:39:16,050 - c3d_training - INFO - Epoch [5/15], Train Loss: 0.0282, Train Acc: 99.04%, Val Loss: 0.0322, Val Acc: 98.98%, Time: 79.84s
2025-03-31 22:39:16,142 - c3d_training - INFO - Epoch [5]: 保存最佳模型, 验证准确率: 98.98%
2025-03-31 22:40:37,157 - c3d_training - INFO - Epoch [6/15], Train Loss: 0.0244, Train Acc: 99.13%, Val Loss: 0.0393, Val Acc: 98.83%, Time: 80.80s
2025-03-31 22:41:59,029 - c3d_training - INFO - Epoch [7/15], Train Loss: 0.0204, Train Acc: 99.40%, Val Loss: 0.0383, Val Acc: 98.95%, Time: 81.64s
2025-03-31 22:43:20,016 - c3d_training - INFO - Epoch [8/15], Train Loss: 0.0219, Train Acc: 99.31%, Val Loss: 0.0578, Val Acc: 98.29%, Time: 80.75s
2025-03-31 22:44:41,369 - c3d_training - INFO - Epoch [9/15], Train Loss: 0.0171, Train Acc: 99.52%, Val Loss: 0.0342, Val Acc: 99.17%, Time: 81.13s
2025-03-31 22:44:41,473 - c3d_training - INFO - Epoch [9]: 保存最佳模型, 验证准确率: 99.17%
2025-03-31 22:46:02,745 - c3d_training - INFO - Epoch [10/15], Train Loss: 0.0173, Train Acc: 99.49%, Val Loss: 0.0279, Val Acc: 99.15%, Time: 81.04s
2025-03-31 22:47:24,128 - c3d_training - INFO - Epoch [11/15], Train Loss: 0.0176, Train Acc: 99.47%, Val Loss: 0.0349, Val Acc: 99.17%, Time: 81.15s
2025-03-31 22:48:44,717 - c3d_training - INFO - Epoch [12/15], Train Loss: 0.0140, Train Acc: 99.53%, Val Loss: 0.0362, Val Acc: 99.27%, Time: 80.36s
2025-03-31 22:48:44,807 - c3d_training - INFO - Epoch [12]: 保存最佳模型, 验证准确率: 99.27%
2025-03-31 22:50:07,041 - c3d_training - INFO - Epoch [13/15], Train Loss: 0.0126, Train Acc: 99.62%, Val Loss: 0.0497, Val Acc: 98.87%, Time: 82.02s
2025-03-31 22:51:27,896 - c3d_training - INFO - Epoch [14/15], Train Loss: 0.0152, Train Acc: 99.49%, Val Loss: 0.0271, Val Acc: 99.32%, Time: 80.61s
2025-03-31 22:51:27,987 - c3d_training - INFO - Epoch [14]: 保存最佳模型, 验证准确率: 99.32%
2025-03-31 22:52:49,409 - c3d_training - INFO - Epoch [15/15], Train Loss: 0.0127, Train Acc: 99.58%, Val Loss: 0.0277, Val Acc: 99.15%, Time: 81.20s
2025-03-31 22:52:49,737 - c3d_training - INFO - 训练完成,最终模型已保存
2025-03-31 22:52:50,003 - c3d_training - INFO - 训练完成!
================================================
FILE: training/train_c3d_pytorch.py
================================================
import os
import glob
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from data.dataclass.NoduleCube import normal_cube_to_tensor
plt.rcParams['font.family'] = ['SimHei'] # 设置字体为黑体
plt.rcParams['axes.unicode_minus'] = False # 正确显示负号
import logging
import time
from datetime import datetime
from models.pytorch_c3d_tiny import C3dTiny
# Model / training hyper-parameters
BATCH_SIZE = 64
EPOCHS = 15
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 1e-5
# Edge length of the training cubes. Not referenced by the visible training
# code; C3dTiny's layer sizes assume 32.
CUBE_SIZE = 32
# Train on CUDA when available, otherwise on CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 设置日志
def setup_logger(log_dir="./pytorch_logs"):
    """Configure the 'c3d_training' logger to write to the console and a timestamped file.

    Args:
        log_dir: directory for ``training_<YYYYmmdd_HHMMSS>.log``; created if missing.

    Returns:
        logging.Logger: the configured ``c3d_training`` logger.
    """
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
    logger = logging.getLogger('c3d_training')
    logger.setLevel(logging.INFO)
    # getLogger returns the same object on every call. Without this cleanup, a
    # second call would attach a second pair of handlers and every message would
    # be emitted twice (and the old log file would stay open).
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    console_handler = logging.StreamHandler()
    for handler in (file_handler, console_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
class Luna16DataSet(Dataset):
    """Dataset over pre-extracted ``.npy`` cubes with one class label per file.

    Args:
        files: sequence of ``.npy`` paths.
        labels: labels aligned with ``files``; each is returned as a ``torch.long`` tensor.
        tranform: optional callable applied to the tensor produced by
            ``normal_cube_to_tensor`` (parameter name kept for compatibility).
    """

    def __init__(self, files, labels, tranform =None):
        self.files, self.labels, self.transform = files, labels, tranform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        cube = normal_cube_to_tensor(np.load(path))
        if self.transform is not None:
            cube = self.transform(cube)
        return cube, target
def load_train_val_data(postive_dir, negative_dir):
    """
    Sample positive/negative cubes and split them 80/20 into train/validation sets.

    ``min(#positives, #negatives)`` positives are sampled, plus up to twice as
    many negatives (capped at the number available). Label 1 marks positives
    and 0 marks negatives. The split uses ``random_state=42``.

    :param postive_dir: directory containing positive ``*.npy`` cubes
    :param negative_dir: directory containing negative ``*.npy`` cubes
    :return: (files_train, files_val, label_train, label_val)
    :raises ValueError: if either directory contains no ``.npy`` file
    """
    # Sort so that random.sample (seeded by the caller) selects the same files
    # regardless of the filesystem's directory-listing order.
    postive_files = sorted(glob.glob(os.path.join(postive_dir, "*.npy")))
    negative_files = sorted(glob.glob(os.path.join(negative_dir, "*.npy")))
    if not postive_files or not negative_files:
        raise ValueError(
            f"No .npy samples found (positive: {len(postive_files)} in {postive_dir}, "
            f"negative: {len(negative_files)} in {negative_dir})")
    min_samples = min(len(postive_files), len(negative_files))
    pos_files = random.sample(postive_files, min_samples)
    # Aim for a 1:2 positive:negative ratio. Never request more negatives than
    # exist, because random.sample raises ValueError in that case.
    neg_count = min(2 * min_samples, len(negative_files))
    neg_files = random.sample(negative_files, neg_count)
    all_files = pos_files + neg_files
    labels = np.concatenate([np.ones(len(pos_files)), np.zeros(len(neg_files))])
    # train_test_split shuffles internally, so no separate shuffle is needed.
    files_train, files_val, label_train, label_val = train_test_split(
        all_files, labels, test_size=0.2, random_state=42)
    return files_train, files_val, label_train, label_val
def _train_one_epoch(model, train_loader, optimizer, criterion, logger, epoch, epoches):
    """Run one training pass; return (mean loss over used batches, accuracy in %)."""
    model.train()
    loss_sum = 0.0
    used_batches = 0
    correct = 0
    total = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Skip non-finite losses rather than applying NaN/Inf gradients.
        if torch.isnan(loss).any() or torch.isinf(loss).any():
            logger.warning(f"警告:损失值包含NaN或Inf,跳过此批次")
            continue
        loss.backward()
        # Clip the gradient norm to guard against exploding updates.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        loss_sum += loss.item()
        used_batches += 1
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        if (i + 1) % 100 == 0:
            print(f"{epoch +1}/{epoches}, Batch [{i + 1}/ {len(train_loader)}], Loss: {loss.item():.4f}")
    # Average only over batches that were actually used. Guard against an
    # epoch in which every batch was skipped or the loader was empty.
    mean_loss = loss_sum / used_batches if used_batches else float('nan')
    accuracy = 100.0 * correct / total if total else 0.0
    return mean_loss, accuracy


def _validate_one_epoch(model, val_loader, criterion):
    """Evaluate without gradients; return (mean loss per batch, accuracy in %)."""
    model.eval()
    loss_sum = 0.0
    batches = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for val_inputs, val_labels in val_loader:
            val_inputs = val_inputs.to(DEVICE)
            val_labels = val_labels.to(DEVICE)
            val_outputs = model(val_inputs)
            loss_sum += criterion(val_outputs, val_labels).item()
            batches += 1
            _, predicted = val_outputs.max(1)
            total += val_labels.size(0)
            correct += predicted.eq(val_labels).sum().item()
    mean_loss = loss_sum / batches if batches else float('nan')
    accuracy = 100.0 * correct / total if total else 0.0
    return mean_loss, accuracy


def train_model(model,train_loader, val_loader, optimizer, criterion, scheduler, logger, writer, epoches, save_dir, checkpoint_every=1):
    """
    Train ``model`` for ``epoches`` epochs, validating after each one.

    Each epoch does the following:
      - one training pass (non-finite losses skipped, gradient norm clipped to 1.0);
      - one validation pass;
      - ``scheduler.step(val_loss)``;
      - TensorBoard scalars and a log line.

    Files written to ``save_dir``:
      - ``best_model.pth``, whenever validation accuracy improves;
      - ``checkpoint_epoch_<n>.pth`` (model + optimizer state), every
        ``checkpoint_every`` epochs;
      - ``final_model.pth``, after the last epoch, along with the metric plot.

    Args:
        model, train_loader, val_loader, optimizer, criterion: usual torch training objects.
        scheduler: called as ``scheduler.step(val_loss)`` (e.g. ReduceLROnPlateau); may be None.
        logger: logging.Logger for epoch summaries.
        writer: TensorBoard SummaryWriter.
        epoches: number of epochs.
        save_dir: output directory; created if missing.
        checkpoint_every: save a full checkpoint every N epochs (default 1 = every
            epoch, the previous behaviour; 0/None disables).

    Returns:
        (train_losses, val_losses, train_accs, val_accs): one entry per epoch.
    """
    best_val_acc = 0.0
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    os.makedirs(save_dir, exist_ok=True)
    for epoch in range(epoches):
        start_time = time.time()
        epoch_train_loss, epoch_train_acc = _train_one_epoch(
            model, train_loader, optimizer, criterion, logger, epoch, epoches)
        train_losses.append(epoch_train_loss)
        train_accs.append(epoch_train_acc)
        epoch_val_loss, epoch_val_acc = _validate_one_epoch(model, val_loader, criterion)
        val_losses.append(epoch_val_loss)
        val_accs.append(epoch_val_acc)
        # The scheduler is driven by the validation loss.
        if scheduler is not None:
            scheduler.step(epoch_val_loss)
        writer.add_scalar('Loss/train', epoch_train_loss, epoch)
        writer.add_scalar('Loss/val', epoch_val_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_train_acc, epoch)
        writer.add_scalar('Accuracy/val', epoch_val_acc, epoch)
        writer.add_scalar('Learning_rate', optimizer.param_groups[0]['lr'], epoch)
        epoch_time = time.time() - start_time
        logger.info(f'Epoch [{epoch+1}/{epoches}], '
                    f'Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%, '
                    f'Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.2f}%, '
                    f'Time: {epoch_time:.2f}s')
        # Keep the weights with the best validation accuracy so far.
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))
            logger.info(f'Epoch [{epoch+1}]: 保存最佳模型, 验证准确率: {epoch_val_acc:.2f}%')
        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': epoch_train_loss,
                'val_loss': epoch_val_loss,
                'train_acc': epoch_train_acc,
                'val_acc': epoch_val_acc
            }, os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'))
    torch.save(model.state_dict(), os.path.join(save_dir, 'final_model.pth'))
    logger.info(f'训练完成,最终模型已保存')
    plot_metrics(train_losses, val_losses, train_accs, val_accs, save_dir)
    return train_losses, val_losses, train_accs, val_accs
def plot_metrics(train_losses, val_losses, train_accs, val_accs, save_dir):
    """Save side-by-side loss and accuracy curves to ``save_dir/training_metrics.png``."""
    panels = (
        (train_losses, val_losses, 'Train Loss', 'Validation Loss',
         'Loss', 'Training and Validation Loss'),
        (train_accs, val_accs, 'Train Accuracy', 'Validation Accuracy',
         'Accuracy (%)', 'Training and Validation Accuracy'),
    )
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (train_curve, val_curve, train_label, val_label, y_label, title) in zip(axes, panels):
        ax.plot(train_curve, label=train_label)
        ax.plot(val_curve, label=val_label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, 'training_metrics.png'))
    plt.close(fig)
def main():
    """Train C3dTiny on the LUNA16 cubes; logs, TensorBoard events and checkpoints go under ./pytorch_*."""
    # Seed the torch, NumPy and Python RNGs (sampling, shuffling, weight init).
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    log_dir = "./pytorch_logs"
    checkpoint_dir = "./pytorch_checkpoints"
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    logger = setup_logger(log_dir)
    writer = SummaryWriter(log_dir=os.path.join(log_dir, 'tensorboard'))
    try:
        # Hard-coded locations of the pre-extracted positive / negative cubes.
        pos_sample_dir = r"J:\luna16_processed\positive_npys"
        neg_sample_dir = r"J:\luna16_processed\negative_npys"
        files_train, files_val, labels_train, labels_val = load_train_val_data(pos_sample_dir, neg_sample_dir)
        logger.info(f"训练集: {len(files_train)}个样本")
        logger.info(f"验证集: {len(files_val)}个样本")
        train_dataset = Luna16DataSet(files_train, labels_train)
        val_dataset = Luna16DataSet(files_val, labels_val)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
        model = C3dTiny().to(DEVICE)
        logger.info(f"模型结构:\n{model}")
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, eps=1e-8)
        # Halve the LR (floor 1e-6) when the validation loss stops improving (patience 5).
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, verbose=True, min_lr=1e-6
        )
        train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            criterion=criterion,
            scheduler=scheduler,
            logger=logger,
            writer=writer,
            epoches=EPOCHS,
            save_dir=checkpoint_dir
        )
    finally:
        # Flush and close the TensorBoard event files even if training fails.
        writer.close()
    logger.info("训练完成!")
# Script entry point. The guard is also required for DataLoader workers
# (num_workers=2) on platforms that spawn worker processes.
if __name__ == "__main__":
    main()
================================================
FILE: util/__init__.py
================================================
================================================
FILE: util/dicom_util.py
================================================
import os
import glob
import pydicom
import numpy as np
import cv2
from tqdm import tqdm
from util.seg_util import get_segmented_lungs,normalize_hu_values
from util.image_util import rescale_patient_images
def is_dicom_file(filename):
    '''
    Check whether a file carries the DICOM Part 10 magic bytes.

    A Part 10 file starts with a 128-byte preamble followed by the four
    ASCII bytes ``DICM``.

    :param filename: path of the file to check
    :return: True if bytes 128..131 equal ``b'DICM'``, otherwise False
    '''
    # `with` closes the handle even if seek/read raises.
    with open(filename, 'rb') as file_stream:
        file_stream.seek(128)
        return file_stream.read(4) == b'DICM'
def get_dicom_thickness(dicom_slices):
    """
    Estimate the slice thickness (z spacing) of a DICOM series.

    Sources are tried in order and the first positive value wins:
      1. |z of ImagePositionPatient| difference between the first two slices,
      2. |SliceLocation| difference between the first two slices,
      3. float(SliceThickness) of the first slice.
    Sources 1-2 need at least two slices. A zero or NaN result (e.g. two
    slices at the same position) is treated as unusable. If nothing usable
    is found, a warning is printed and 1.0 is returned.

    :param dicom_slices: sequence of DICOM datasets (as read by pydicom)
    :return: slice thickness
    """
    # Missing tags raise AttributeError, short sequences IndexError, and bad
    # values TypeError/ValueError. Anything else is a real error and propagates
    # (the old bare `except:` even swallowed KeyboardInterrupt).
    lookup_errors = (AttributeError, IndexError, KeyError, TypeError, ValueError)
    getters = []
    if len(dicom_slices) > 1:
        first, second = dicom_slices[0], dicom_slices[1]
        getters.append(lambda: abs(first.ImagePositionPatient[2] - second.ImagePositionPatient[2]))
        getters.append(lambda: abs(first.SliceLocation - second.SliceLocation))
    getters.append(lambda: float(dicom_slices[0].SliceThickness))
    for getter in getters:
        try:
            thickness = getter()
        except lookup_errors:
            continue
        if thickness > 0:
            return thickness
    if len(dicom_slices) > 1:
        print("警告: 无法确定切片厚度,使用默认值1.0mm")
    else:
        print("警告: 只有一个切片,无法计算切片厚度,使用默认值1.0mm")
    return 1.0
def load_dicom_slices(dicom_path):
    """
    Recursively collect, read and order the DICOM files under a directory.

    A file is kept when its extension is .dcm/.dicom (case-insensitive) and
    ``is_dicom_file`` accepts it. Files pydicom cannot read are reported and
    skipped. The slices are sorted by their InstanceNumber tag, and every
    slice's SliceThickness is overwritten with ``get_dicom_thickness(slices)``.

    :param dicom_path: root directory of one series
    :return: list of pydicom datasets
    :raises ValueError: if no DICOM file is found
    """
    dicom_files = []
    for root, _, files in os.walk(dicom_path):
        for file in files:
            if not file.lower().endswith(('.dcm', '.dicom')):
                continue
            # os.walk's `root` already starts with dicom_path. Joining
            # dicom_path again doubled the prefix for relative paths.
            real_file = os.path.join(root, file)
            if is_dicom_file(real_file):
                dicom_files.append(real_file)
    if not dicom_files:
        raise ValueError(f"在路径 {dicom_path} 中未找到DICOM文件")
    slices = []
    for file in dicom_files:
        try:
            slices.append(pydicom.dcmread(file))
        except Exception as e:
            # Best effort: report unreadable files and continue with the rest.
            print(f"无法读取DICOM文件 {file}: {e}")
    # Ordered by the InstanceNumber tag (not by ImagePositionPatient z).
    slices.sort(key=lambda x: int(x.InstanceNumber))
    slice_thickness = get_dicom_thickness(slices)
    for s in slices:
        s.SliceThickness = slice_thickness
    return slices
def get_pixels_hu(slices):
    '''
    Stack the slices' pixel data into one int16 volume and apply the rescale.

    Stored values of -2000 are first reset to 0. Each slice is then multiplied
    by its RescaleSlope (when the slope is not 1, truncated back to int16) and
    offset by its RescaleIntercept, all in int16 arithmetic.

    :param slices: sequence of DICOM datasets (pixel_array, RescaleSlope, RescaleIntercept)
    :return: int16 array with one plane per slice
    '''
    volume = np.stack([ds.pixel_array for ds in slices]).astype(np.int16)
    volume[volume == -2000] = 0
    for idx, ds in enumerate(slices):
        intercept = ds.RescaleIntercept
        slope = ds.RescaleSlope
        if slope != 1:
            volume[idx] = (slope * volume[idx].astype(np.float64)).astype(np.int16)
        volume[idx] += np.int16(intercept)
    return np.array(volume, dtype=np.int16)
def getinfo_dicom(dicom_path):
    '''
    Print diagnostic information about a DICOM series and return its geometry.

    :param dicom_path: directory holding one series
    :return: (pixel_spacing, dicom_size, invert_order) where
             pixel_spacing is the first slice's PixelSpacing with its
             SliceThickness appended, dicom_size is the first three dimensions
             of the stacked volume, and invert_order is True when the second
             slice's ImagePositionPatient z is greater than the first's
    '''
    print('dicom_path: ', dicom_path)
    series = load_dicom_slices(dicom_path)
    first = series[0]
    print(type(first), first.ImagePositionPatient)
    print(len(series), "\t", first.SliceThickness, "\t", first.PixelSpacing)
    print("Orientation: ", first.ImageOrientationPatient)
    volume = get_pixels_hu(series)
    print(volume.shape)
    z_next = series[1].ImagePositionPatient[2]
    z_first = first.ImagePositionPatient[2]
    invert_order = z_next > z_first
    print("Invert order: ", invert_order, " - ", z_next, ",", z_first)
    # NOTE: this appends to the first dataset's own PixelSpacing object, which
    # is also what gets returned.
    pixel_spacing = first.PixelSpacing
    pixel_spacing.append(first.SliceThickness)
    dicom_size = [volume.shape[axis] for axis in range(3)]
    return pixel_spacing, dicom_size, invert_order
def extract_dicom_images_patient(dicom_path, target_dir):
    '''
    Convert a DICOM series into per-slice PNG image / lung-mask pairs.

    Steps:
      1. Rescale the HU volume with ``rescale_patient_images``.
      2. Flip it along axis 0 when ``invert_order`` is False.
      3. Write each plane ``i`` to ``target_dir``:
         - ``img_NNNN_i.png``: ``normalize_hu_values(plane) * 255``;
         - ``img_NNNN_m.png``: the mask from ``get_segmented_lungs(plane) * 255``.

    If ``target_dir`` already exists, nothing is written.

    :param dicom_path: directory holding one series
    :param target_dir: output directory for the PNGs (parents created as needed)
    :return: (pixel_spacing, dicom_size, png_size, invert_order)
    '''
    slices = load_dicom_slices(dicom_path)
    # Only axial series with identity row/column orientation are supported.
    # NOTE: `assert` is skipped under `python -O`.
    assert slices[0].ImageOrientationPatient == [1.000000, 0.000000, 0.000000, 0.000000, 1.000000, 0.000000]
    image = get_pixels_hu(slices)
    invert_order = slices[1].ImagePositionPatient[2] > slices[0].ImagePositionPatient[2]
    pixel_spacing = slices[0].PixelSpacing
    pixel_spacing.append(slices[0].SliceThickness)
    # Size of the volume before resampling
    dicom_size = [image.shape[0], image.shape[1], image.shape[2]]
    image = rescale_patient_images(image, pixel_spacing)
    png_size = [image.shape[0], image.shape[1], image.shape[2]]
    if not invert_order:
        image = np.flipud(image)
    if os.path.exists(target_dir):
        print("png dir already exists, return directly")
        return pixel_spacing, dicom_size, png_size, invert_order
    # The directory is new (hence empty), so no stale PNGs need removing.
    os.makedirs(target_dir)
    for i in tqdm(range(image.shape[0])):
        # The old code referenced an undefined `patient_dir` here (NameError).
        stem = os.path.join(target_dir, "img_" + str(i).rjust(4, '0'))
        org_img = image[i]
        _, mask = get_segmented_lungs(org_img.copy())
        org_img = normalize_hu_values(org_img)
        cv2.imwrite(stem + "_i.png", org_img * 255)
        cv2.imwrite(stem + "_m.png", mask * 255)
    return pixel_spacing, dicom_size, png_size, invert_order
================================================
FILE: util/image_util.py
================================================
from typing import Tuple
import cv2
import os
import numpy
import glob
import random
import numpy as np
from scipy import ndimage
def get_normalized_img_unit8(img):
    '''
    Min-max scale an array to the full uint8 range [0, 255].

    :param img: numeric array
    :return: uint8 array of the same shape (values truncated after scaling);
             all zeros when the input is constant
    '''
    # `numpy.float` was removed in NumPy 1.24; it was an alias of float64.
    img = img.astype(np.float64)
    lo = img.min()
    hi = img.max()
    span = hi - lo
    if span == 0:
        # Constant input: avoid 0/0 (NaN) and return an all-black image.
        return np.zeros(img.shape, dtype=np.uint8)
    img -= lo
    img /= span
    img *= 255
    return img.astype(np.uint8)
def load_patient_images(png_path, wildcard="*.*", exclude_wildcards=None):
    '''
    Load the grayscale images matching a pattern into one 3D array.

    :param png_path: directory holding the images
    :param wildcard: glob pattern, relative to ``png_path``, selecting the images
    :param exclude_wildcards: glob patterns, relative to ``png_path``, whose
                              matches are dropped (None/empty: drop nothing)
    :return: array of shape (n_images, height, width), images in sorted path order
    '''
    print("png path is\t",png_path)
    src_dir = png_path
    src_img_paths = glob.glob(src_dir + '/' + wildcard)
    excluded = set()
    for exclude_wildcard in exclude_wildcards or ():
        # Build the pattern exactly like the include pattern so the returned
        # paths compare equal. The old code omitted the '/', so "*_m.png" only
        # matched when png_path ended with a separator.
        excluded.update(glob.glob(src_dir + '/' + exclude_wildcard))
    src_img_paths = sorted(p for p in src_img_paths if p not in excluded)
    images = [cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) for img_path in src_img_paths]
    images = [im.reshape((1, ) + im.shape) for im in images]
    return numpy.vstack(images)
def draw_overlay(png_path: str, p_x: float, p_y: float, p_z: float, index: str, BOX_size:int = 20) -> None:
    """
    Draw a square outline on one slice of a patient's PNG stack and save it.

    Loads ``<png_path>/png/*_i.png`` via load_patient_images and draws a box
    with half-side ``BOX_size`` around the requested point of the selected
    slice. The result is written to ``<png_path>/<index>.png``.

    Args:
        png_path: patient directory containing the ``png`` sub-directory
        p_x: X position as a fraction of the image width
        p_y: Y position as a fraction of the image height
        p_z: Z position as a fraction of the number of slices
        index: output file name stem
        BOX_size: half side length of the box in pixels
    """
    patient_img = load_patient_images(png_path + "/png/", "*_i.png", [])
    depth = patient_img.shape[0]
    # p_z == 1.0 would index one past the last slice; clamp into range.
    z = min(max(int(p_z * depth), 0), depth - 1)
    y = int(p_y * patient_img.shape[1])
    x = int(p_x * patient_img.shape[2])
    x1, y1 = x - BOX_size, y - BOX_size
    x2, y2 = x + BOX_size, y + BOX_size
    target_img = patient_img[z, :, :]
    cv2.rectangle(target_img, (x1, y1), (x2, y2), (255, 0, 0), 1)
    cv2.imwrite(png_path + "/" + index + ".png", target_img)
def prepare_image_for_net3D(img,MEAN_PIXEL_VALUE = 41):
    '''
    Zero-centre and scale a volume, then add leading batch and trailing channel axes.

    Computes ``(img - MEAN_PIXEL_VALUE) / 255`` in float32.

    :param img: array with at least three axes (the first three are kept)
    :param MEAN_PIXEL_VALUE: value subtracted before scaling
    :return: float32 array shaped (1, d0, d1, d2, 1)
    '''
    volume = img.astype(np.float32)
    volume -= MEAN_PIXEL_VALUE
    volume /= 255.
    d0, d1, d2 = img.shape[0], img.shape[1], img.shape[2]
    return volume.reshape(1, d0, d1, d2, 1)
def move_png2dir(target_dir):
    '''
    Move the ``.png`` files of each first-level sub-directory into a ``png`` folder.

    For every directory ``target_dir/<d>``, files whose path ends in ``.png``
    are moved to ``target_dir/<d>/png/``. That folder is created as soon as
    ``<d>`` has at least one entry. Each move is printed.

    :param target_dir: directory whose immediate sub-directories are processed
    '''
    import shutil
    candidates = (os.path.join(target_dir, entry) for entry in os.listdir(target_dir))
    sub_dirs = [path for path in candidates if os.path.isdir(path)]
    for sub_dir in sub_dirs:
        png_dir = os.path.join(sub_dir, 'png')
        for entry in os.listdir(sub_dir):
            if not os.path.exists(png_dir):
                os.mkdir(png_dir)
            source = os.path.join(sub_dir, entry)
            if not source.endswith(".png"):
                continue
            destination = os.path.join(png_dir, entry)
            shutil.move(source, destination)
            print("move file from %s to %s " % (source, destination))
def rescale_patient_images(images_zyx, org_spacing_xyz, target_voxel_mm =1.0, is_mask_image=False, verbose=False):
    '''
    Resample a (z, y, x) volume so each axis is scaled by ``spacing / target_voxel_mm``.

    z is resized first (x temporarily used as OpenCV channels), then y and x
    (z used as channels).

    :param images_zyx: 3D array ordered (z, y, x)
    :param org_spacing_xyz: spacing indexed as [0] -> x, [1] -> y, [2] -> z
    :param target_voxel_mm: desired voxel size, same unit as the spacing
    :param is_mask_image: use nearest-neighbour interpolation (for label masks)
    :param verbose: print the spacing and shapes
    :return: resampled array ordered (z, y, x)
    '''
    interpolation = cv2.INTER_NEAREST if is_mask_image else cv2.INTER_LINEAR

    def _resize(img, fx, fy):
        # cv2.resize treats axis 2 as channels and supports at most 512 of them.
        # The old split into [:256] / [256:] failed above 768 channels and was
        # missing entirely on the first resize, so resize in 512-channel chunks.
        channels = img.shape[2] if img.ndim == 3 else 1
        if channels <= 512:
            return cv2.resize(img, dsize=None, fx=fx, fy=fy, interpolation=interpolation)
        parts = []
        for start in range(0, channels, 512):
            part = cv2.resize(np.ascontiguousarray(img[:, :, start:start + 512]),
                              dsize=None, fx=fx, fy=fy, interpolation=interpolation)
            if part.ndim == 2:
                # cv2 drops the channel axis of single-channel results.
                part = part[:, :, np.newaxis]
            parts.append(part)
        return np.concatenate(parts, axis=2)

    if verbose:
        print("Spacing: ", org_spacing_xyz)
        print("Shape: ", images_zyx.shape)
    # rows = z, cols = y, channels = x: scale z only.
    res = _resize(images_zyx, 1.0, float(org_spacing_xyz[2]) / float(target_voxel_mm))
    # (z', y, x) -> (y, x, z'): rows = y, cols = x, channels = z'.
    res = res.swapaxes(0, 2).swapaxes(0, 1)
    res = _resize(res,
                  float(org_spacing_xyz[0]) / float(target_voxel_mm),
                  float(org_spacing_xyz[1]) / float(target_voxel_mm))
    # (y', x', z') -> (z', y', x')
    res = res.swapaxes(0, 2).swapaxes(2, 1)
    if verbose:
        print("Shape after: ", res.shape)
    return res
def rescale_patient_images2(images_zyx, target_shape, verbose=False):
    '''
    Resample a (z, y, x) volume to an explicit (z, y, x) target shape.

    z and y are resized first (x temporarily used as OpenCV channels), then
    y and x (z used as channels), always with linear interpolation.

    :param images_zyx: 3D array ordered (z, y, x)
    :param target_shape: desired shape ordered (z, y, x)
    :param verbose: print the target and shapes
    :return: resampled array ordered (z, y, x)
    '''
    interpolation = cv2.INTER_LINEAR

    def _resize(img, dsize):
        # cv2.resize treats axis 2 as channels and supports at most 512 of them.
        # The old split into [:256] / [256:] failed above 768 channels and was
        # missing entirely on the first resize, so resize in 512-channel chunks.
        channels = img.shape[2] if img.ndim == 3 else 1
        if channels <= 512:
            return cv2.resize(img, dsize=dsize, interpolation=interpolation)
        parts = []
        for start in range(0, channels, 512):
            part = cv2.resize(np.ascontiguousarray(img[:, :, start:start + 512]),
                              dsize=dsize, interpolation=interpolation)
            if part.ndim == 2:
                # cv2 drops the channel axis of single-channel results.
                part = part[:, :, np.newaxis]
            parts.append(part)
        return np.concatenate(parts, axis=2)

    if verbose:
        print("Target: ", target_shape)
        print("Shape: ", images_zyx.shape)
    # rows = z -> target z, cols = y -> target y (dsize is (width, height)).
    res = _resize(images_zyx, (target_shape[1], target_shape[0]))
    # (z', y', x) -> (y', x, z'): rows = y, cols = x, channels = z'.
    res = res.swapaxes(0, 2).swapaxes(0, 1)
    res = _resize(res, (target_shape[2], target_shape[1]))
    # (y', x', z') -> (z', y', x')
    res = res.swapaxes(0, 2).swapaxes(2, 1)
    if verbose:
        print("Shape after: ", res.shape)
    return res
def resize_image(image: np.ndarray, new_shape: Tuple[int, ...]) -> np.ndarray:
    """
    Resize an array to a new shape.

    Args:
        image: input array
        new_shape: target shape

    Returns:
        np.ndarray:
          - 3D input with a 2D target: each plane is resized with cv2 into a
            float64 array of shape (n_planes, *new_shape);
          - 2D input with a 2D target: a single cv2 resize;
          - anything else: ``scipy.ndimage.zoom`` with mode='nearest'.
    """
    if len(new_shape) == 2:
        cv_size = (new_shape[1], new_shape[0])  # cv2 wants (width, height)
        if len(image.shape) == 3:
            stacked = np.zeros((image.shape[0], new_shape[0], new_shape[1]))
            for idx, plane in enumerate(image):
                stacked[idx] = cv2.resize(plane, cv_size)
            return stacked
        if len(image.shape) == 2:
            return cv2.resize(image, cv_size)
    zoom_factors = tuple(target / source for target, source in zip(new_shape, image.shape))
    return ndimage.zoom(image, zoom_factors, mode='nearest')
def cv_flip(img,cols,rows,degree):
    '''
    Rotate (despite the name, not mirror) an image about its centre.

    :param img: image array
    :param cols: output width, also used for the rotation centre
    :param rows: output height, also used for the rotation centre
    :param degree: rotation angle in degrees, as passed to cv2.getRotationMatrix2D
    :return: rotated image of size (rows, cols)
    '''
    centre = (cols / 2, rows / 2)
    rotation = cv2.getRotationMatrix2D(centre, degree, 1.0)
    return cv2.warpAffine(img, rotation, (cols, rows))
def random_rotate_img(img, chance, min_angle, max_angle):
    '''
    With probability ``chance``, rotate an image (or a list of images) by one random angle.

    All images share the same angle and the rotation centre of the first
    image. Each keeps its own size; borders are filled with a constant.

    :param img: image array, or list of image arrays
    :param chance: probability of rotating at all
    :param min_angle: minimum angle in degrees (inclusive)
    :param max_angle: maximum angle in degrees (inclusive)
    :return: ``img`` unchanged when not rotated. Otherwise a single rotated
             image when one image was given, or a list of rotated images.
    '''
    if random.random() > chance:
        return img
    images = img if isinstance(img, list) else [img]
    angle = random.randint(min_angle, max_angle)
    height, width = images[0].shape[:2]
    # OpenCV takes the centre as (x, y) and dsize as (width, height). The old
    # code passed (rows, cols), which is only correct for square images.
    center = (width / 2, height / 2)
    rot_matrix = cv2.getRotationMatrix2D(center, angle, scale=1.0)
    res = [
        cv2.warpAffine(img_inst, rot_matrix, dsize=(img_inst.shape[1], img_inst.shape[0]),
                       borderMode=cv2.BORDER_CONSTANT)
        for img_inst in images
    ]
    # Unwrap a single result, as random_translate_img does. The old `== 0`
    # test never unwrapped and would have raised IndexError on an empty list.
    if len(res) == 1:
        res = res[0]
    return res
def random_flip_img(img, horizontal_chance=0, vertical_chance=0):
    '''
    Randomly mirror an image (or each image of a list) horizontally and/or vertically.

    Two random draws decide the flips: horizontal first, then vertical.

    :param img: image array or list of image arrays
    :param horizontal_chance: probability of a horizontal flip
    :param vertical_chance: probability of a vertical flip
    :return: ``img`` unchanged when no flip is drawn. Otherwise the flipped
             image, or a list of flipped images for list input.
    '''
    import cv2
    do_horizontal = random.random() < horizontal_chance
    do_vertical = random.random() < vertical_chance
    if not (do_horizontal or do_vertical):
        return img
    # cv2.flip codes: 1 = horizontal, 0 = vertical, -1 = both.
    if do_vertical:
        flip_code = -1 if do_horizontal else 0
    else:
        flip_code = 1
    if isinstance(img, list):
        return [cv2.flip(item, flip_code) for item in img]
    return cv2.flip(img, flip_code)
def random_scale_img(img, xy_range, lock_xy=False):
    '''
    Randomly rescale an image (or list of images) in x/y while keeping its size.

    Scale factors are drawn uniformly from ``xy_range``. A shrunk image is
    zero-padded back to its original size and an enlarged one is centre-cropped.
    The drawn factors are stored in ``xy_range.last_x`` / ``xy_range.last_y``.

    :param img: image array or list of image arrays with the same size
    :param xy_range: object with chance, x_min/x_max, y_min/y_max (e.g. XYRange)
    :param lock_xy: reuse the x factor for y (uniform scaling)
    :return: ``img`` unchanged when skipped. Otherwise a single scaled image
             when one image was given, or a list of scaled images.
    '''
    if random.random() > xy_range.chance:
        return img
    images = img if isinstance(img, list) else [img]
    scale_x = random.uniform(xy_range.x_min, xy_range.x_max)
    scale_y = random.uniform(xy_range.y_min, xy_range.y_max)
    if lock_xy:
        scale_y = scale_x
    org_height, org_width = images[0].shape[:2]
    xy_range.last_x = scale_x
    xy_range.last_y = scale_y
    res = []
    for img_inst in images:
        scaled_width = int(org_width * scale_x)
        scaled_height = int(org_height * scale_y)
        scaled_img = cv2.resize(img_inst, (scaled_width, scaled_height), interpolation=cv2.INTER_CUBIC)
        # Integer division: copyMakeBorder and slicing require ints. The old
        # `/` produced floats and raised TypeError on Python 3.
        if scaled_width < org_width:
            extend_left = (org_width - scaled_width) // 2
            extend_right = org_width - extend_left - scaled_width
            scaled_img = cv2.copyMakeBorder(scaled_img, 0, 0, extend_left, extend_right, borderType=cv2.BORDER_CONSTANT)
            scaled_width = org_width
        if scaled_height < org_height:
            extend_top = (org_height - scaled_height) // 2
            extend_bottom = org_height - extend_top - scaled_height
            scaled_img = cv2.copyMakeBorder(scaled_img, extend_top, extend_bottom, 0, 0, borderType=cv2.BORDER_CONSTANT)
            scaled_height = org_height
        # Centre crop back to the original size.
        start_x = (scaled_width - org_width) // 2
        start_y = (scaled_height - org_height) // 2
        res.append(scaled_img[start_y: start_y + org_height, start_x: start_x + org_width])
    # Unwrap a single result, as random_translate_img does.
    if len(res) == 1:
        res = res[0]
    return res
class XYRange:
    """Bounds for random x/y augmentation parameters, plus the values last drawn.

    Args:
        x_min, x_max: range for the x parameter.
        y_min, y_max: range for the y parameter.
        chance: probability that the augmentation is applied (default 1.0).
    """

    def __init__(self, x_min, x_max, y_min, y_max, chance=1.0):
        self.chance = chance
        self.x_min, self.x_max = x_min, x_max
        self.y_min, self.y_max = y_min, y_max
        # Overwritten by random_scale_img / random_translate_img with their draws.
        self.last_x = 0
        self.last_y = 0

    def get_last_xy_txt(self):
        """Return a tag such as 'x_m5-y_10' (last_x=-0.05, last_y=0.1); 'm' replaces '-'."""
        def _as_tag(value):
            return str(int(value * 100)).replace("-", "m")
        return "x_" + _as_tag(self.last_x) + "-" + "y_" + _as_tag(self.last_y)
def random_translate_img(img, xy_range, border_mode="constant"):
    '''
    With probability ``xy_range.chance``, shift an image (or list of images) by a random integer offset.

    The offsets are drawn with random.randint from ``xy_range`` and stored in
    ``xy_range.last_x`` / ``last_y``. All images share the same offset and the
    size of the first image.

    :param img: image array or list of image arrays
    :param xy_range: object with chance, x_min/x_max, y_min/y_max (e.g. XYRange)
    :param border_mode: "reflect" for mirrored borders, anything else for constant fill
    :return: ``img`` unchanged when skipped. Otherwise a single shifted image
             when one image was given, or a list of shifted images.
    '''
    if random.random() > xy_range.chance:
        return img
    import cv2
    images = img if isinstance(img, list) else [img]
    height, width = images[0].shape[:2]
    dx = random.randint(xy_range.x_min, xy_range.x_max)
    dy = random.randint(xy_range.y_min, xy_range.y_max)
    shift_matrix = numpy.float32([[1, 0, dx], [0, 1, dy]])
    border = cv2.BORDER_REFLECT if border_mode == "reflect" else cv2.BORDER_CONSTANT
    shifted = [cv2.warpAffine(item, shift_matrix, (width, height), borderMode=border) for item in images]
    xy_range.last_x = dx
    xy_range.last_y = dy
    return shifted[0] if len(shifted) == 1 else shifted
def data_augmentation(image: np.ndarray, augment_type: str = 'random') -> np.ndarray:
    """
    Apply one augmentation to an array, drawing randomness from np.random.

    Args:
        image: input array
        augment_type: one of the following; any other value returns ``image`` unchanged.
            - 'flip': np.flip along a random axis;
            - 'rotate': random whole-degree angle in [0, 359], in a random pair
              of axes when ndim != 2;
            - 'shift': random integer offset in [-5, 5] per axis;
            - 'random': uniformly one of flip / rotate / shift / none.

    Returns:
        np.ndarray: augmented array. Rotation keeps the shape (reshape=False);
        rotate and shift fill borders with mode='nearest'.
    """
    if augment_type == 'random':
        picked = np.random.choice(['flip', 'rotate', 'shift', 'none'])
        if picked == 'none':
            return image
        return data_augmentation(image, picked)
    if augment_type == 'flip':
        return np.flip(image, axis=np.random.randint(0, image.ndim))
    if augment_type == 'rotate':
        if image.ndim == 2:
            return ndimage.rotate(image, np.random.randint(0, 360), reshape=False, mode='nearest')
        rotation_plane = tuple(np.random.choice(range(image.ndim), size=2, replace=False))
        return ndimage.rotate(image, np.random.randint(0, 360), axes=rotation_plane, reshape=False, mode='nearest')
    if augment_type == 'shift':
        offsets = np.random.randint(-5, 6, size=image.ndim)
        return ndimage.shift(image, offsets, mode='nearest')
    return image
================================================
FILE: util/mhd_util.py
================================================
import os
import ntpath
import SimpleITK
import numpy as np
import pandas as pd
import cv2
from data.dataclass.CTData import CTData
from util.seg_util import normalize_hu_values,get_segmented_lungs
from util import image_util
from constant import tianchi
# Target voxel size in millimetres; spacing / TARGET_VOXEL_MM gives the per-axis rescale factor.
TARGET_VOXEL_MM = 1.0
# CSV header matching, position by position, the list returned by extract_image_from_mhd.
# NOTE(review): extract_image_from_mhd stores the array shape in reverse order, so
# shape_0/shape_1/shape_2 hold shape[2]/shape[1]/shape[0] of the loaded array.
# direction_z(1_-1) is "1" when the direction matrix equals the identity, otherwise "-1".
MHD_INFO_HEAD = "patient_id,shape_0,shape_1,shape_2,origin_x,origin_y,origin_z,direction_z(1_-1)," \
"spacing_x,spacing_y,spacing_z,rescale_x,rescale_y,rescale_z"
def get_all_mhd_file(BASE_DATA_DIR, base_head, max):
    """
    Collect the .mhd files of the numbered Tianchi subset directories.

    The data is split into directories named ``<base_head>_subsetNN``
    (train_subset00, train_subset01, ..., test_subset00, ...). This function
    scans the suffixes 00 up to ``max - 1``: ``max`` is an EXCLUSIVE bound.
    For example, to include train_subset00 .. train_subset09, pass max=10.

    :param BASE_DATA_DIR: directory containing the subset directories
    :param base_head: 'train', 'test' or 'val'
    :param max: number of subsets to scan (exclusive upper bound of the suffix)
    :return: list of full paths of files ending in ".mhd", subset by subset
        (within a subset, in os.listdir order)
    :raises FileNotFoundError: if one of the expected subset directories is missing
    """
    mhd_files = []
    for index in range(max):
        # Suffixes are zero-padded to at least two digits: 0 -> "00", 9 -> "09", 10 -> "10".
        sub_path = os.path.join(BASE_DATA_DIR, f"{base_head}_subset{index:02d}")
        mhd_files.extend(
            os.path.join(sub_path, name)
            for name in os.listdir(sub_path)
            if name.endswith(".mhd")
        )
    return mhd_files
def get_luna16_mhd_file(mhd_root):
    """
    Recursively collect every .mhd file below a LUNA16 data directory.

    :param mhd_root: root directory to walk; all sub-directories (e.g. subset0 ..
        subset9) are searched
    :return: list of paths of files whose name ends with ".mhd" (case-insensitive).
        Paths are relative when mhd_root is relative and absolute when it is absolute.
    """
    mhd_files = []
    for root, _, files in os.walk(mhd_root):
        # os.walk yields ``root`` already prefixed with mhd_root, so it must not be
        # joined with mhd_root again (that doubled the prefix for relative roots).
        mhd_files.extend(
            os.path.join(root, name) for name in files if name.lower().endswith('.mhd')
        )
    return mhd_files
def read_csv_to_pandas(mhd_info, col_sepator='\t'):
    """
    Read a text file of MHD information (one header line plus data rows) into a DataFrame.

    The header line is split on ',' while data rows are split on ``col_sepator``.
    The first field of every row (the patient id) is also used as the row index.
    All values are kept as strings. Blank lines are ignored.

    :param mhd_info: path of the csv/text file
    :param col_sepator: separator between the columns of the data rows (default tab)
    :return: DataFrame whose columns come from the header and whose index holds
        the first field of every row
    """
    with open(mhd_info, 'r') as f:
        # Strip the line terminator so the last column name / value has no trailing '\n'.
        # NOTE(review): the header is always split on ',' regardless of col_sepator,
        # presumably because it is written from the comma-separated MHD_INFO_HEAD.
        head = f.readline().rstrip('\r\n').split(",")
        rows = []
        for line in f:
            line = line.rstrip('\r\n')
            if not line:
                continue  # e.g. a trailing empty line would otherwise break the column count
            rows.append(line.split(col_sepator))
    index = [row[0] for row in rows]  # the first element is the patient id
    return pd.DataFrame(data=rows, columns=head, index=index)
def extract_image_from_mhd(mhd_file_path, png_save_path_root=None):
    """
    Read an .mhd scan, collect its header information and optionally export
    every slice (along the first array axis) as PNG images with a lung mask.

    :param mhd_file_path: path of the .mhd file to read
    :param png_save_path_root: root directory for exported images. A
        ``<patient_id>`` sub-directory is created in it, holding
        ``img_NNNN_i.png`` (HU-normalized slice) and ``img_NNNN_m.png``
        (lung mask) for every slice. If None, only the header information is
        returned and nothing is written to disk. If the patient sub-directory
        already exists, the export is skipped; it is assumed to come from a
        previous run.
    :return: list of strings in the column order of MHD_INFO_HEAD:
        - patient id
        - array shape in reverse order (shape[2], shape[1], shape[0])
        - origin
        - direction flag ("1" when the direction matrix is the identity, else "-1")
        - spacing
        - rescale factor (spacing / TARGET_VOXEL_MM)
    """
    mhd_info = []
    patient_id = ntpath.basename(mhd_file_path).replace(".mhd", "")
    print("Patient: ", patient_id)
    mhd_info.append(patient_id)

    itk_img = SimpleITK.ReadImage(mhd_file_path)
    img_array = SimpleITK.GetArrayFromImage(itk_img)
    print("Img array: ", img_array.shape)
    (shape0, shape1, shape2) = img_array.shape
    mhd_info.extend([str(shape2), str(shape1), str(shape0)])

    origin = np.array(itk_img.GetOrigin())  # x,y,z origin in world coordinates (mm)
    print("Origin (x,y,z): ", origin)
    mhd_info.extend([str(origin[0]), str(origin[1]), str(origin[2])])

    # Direction matrix as a flat sequence, compared with the flattened 3x3 identity.
    direction = np.array(itk_img.GetDirection())
    print("Direction: ", direction)
    identity_direction = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
    if direction.tolist() == identity_direction:
        print("positive direction..")
        mhd_info.append(str(1))
    else:
        mhd_info.append(str(-1))

    spacing = np.array(itk_img.GetSpacing())  # spacing of voxels in world coordinates (mm)
    print("Spacing (x,y,z): ", spacing)
    mhd_info.extend([str(spacing[0]), str(spacing[1]), str(spacing[2])])

    rescale = spacing / TARGET_VOXEL_MM
    print("Rescale: ", rescale)
    mhd_info.extend([str(rescale[0]), str(rescale[1]), str(rescale[2])])

    if png_save_path_root is None:  # header information only, nothing is written
        return mhd_info

    dst_dir = os.path.join(png_save_path_root, patient_id)
    # Existence must be checked BEFORE creating the directory; otherwise the
    # export below could never run.
    if os.path.exists(dst_dir):
        return mhd_info
    os.makedirs(dst_dir, exist_ok=True)

    if img_array.shape[1] == 512:
        img_array = image_util.rescale_patient_images(img_array, spacing, TARGET_VOXEL_MM)
    for i, img in enumerate(img_array):
        # get_segmented_lungs modifies its input in place, hence the copy.
        _, mask = get_segmented_lungs(img.copy())
        img = normalize_hu_values(img)
        stem = os.path.join(dst_dir, "img_" + str(i).rjust(4, '0'))
        # Convert to uint8 explicitly: PNG needs 8/16-bit unsigned data, and
        # bool * 255 would otherwise yield an int64 array.
        cv2.imwrite(stem + "_i.png", (img * 255).astype(np.uint8))
        cv2.imwrite(stem + "_m.png", mask.astype(np.uint8) * 255)
    return mhd_info
================================================
FILE: util/progress_watch.py
================================================
import datetime
class Stopwatch(object):
    """
    Minimal stopwatch based on ``datetime.datetime.now()``.

    NOTE(review): this uses the wall clock, so elapsed times can be distorted if
    the system clock is adjusted while timing.
    """

    def start(self):
        """(Re)start timing from the current moment."""
        self.start_time = Stopwatch.get_time()

    def get_elapsed_time(self):
        """Return the ``datetime.timedelta`` since the last ``start()``.

        ``start()`` must have been called first; otherwise ``start_time`` does not exist.
        """
        return Stopwatch.get_time() - self.start_time

    def get_elapsed_seconds(self):
        """Return the time since the last ``start()`` as float seconds."""
        return self.get_elapsed_time().total_seconds()

    @staticmethod
    def get_time():
        """Return the current local time as a naive ``datetime``."""
        return datetime.datetime.now()

    @staticmethod
    def start_new():
        """Create a ``Stopwatch`` that is already running."""
        watch = Stopwatch()
        watch.start()
        return watch
================================================
FILE: util/seg_util.py
================================================
import numpy as np
import matplotlib.pyplot as plt
# NOTE: these run at import time and change matplotlib's global configuration for the whole process.
plt.rcParams['font.family'] = ['SimHei'] # use the SimHei ("Hei" / black) font, presumably so Chinese plot text renders
plt.rcParams['axes.unicode_minus'] = False # draw minus signs with ASCII '-' so they display correctly with this font
from scipy import ndimage as ndi
from skimage.filters import roberts
from skimage.measure import regionprops, label
from skimage.morphology import binary_closing, disk, binary_erosion
from skimage.segmentation import clear_border
def normalize_hu_values(image: np.ndarray, min_bound: int = -1000, max_bound: int = 400) -> np.ndarray:
    """
    Linearly map HU values from [min_bound, max_bound] onto [0, 1], clipping values outside that window.

    Args:
        image: input array of HU values; it is not modified.
        min_bound: HU value mapped to 0 (lower values are clipped to 0).
        max_bound: HU value mapped to 1 (higher values are clipped to 1).
    Returns:
        np.ndarray: a newly allocated floating-point array with values in [0, 1].
    """
    window = max_bound - min_bound
    scaled = (image - min_bound) / window
    # Clip in place on the freshly allocated result, leaving the caller's array untouched.
    np.clip(scaled, 0., 1., out=scaled)
    return scaled
def get_segmented_lungs(im, plot=False):
    '''
    Extract the lung region of interest from a single 2-D slice.

    :param im: 2-D pixel array of one slice. Values are thresholded at -400, so
        HU input is presumably expected. NOTE: ``im`` is modified IN PLACE: every
        pixel outside the lung mask is set to -2000. Pass a copy if the original
        values must be preserved.
    :param plot: if True, show the mask and the masked image with matplotlib
    :return: tuple ``(im, binary)`` with the masked slice (the same object as the
        input) and the lung mask
    '''
    # Step 1: Convert into a binary image (pixels darker than -400).
    binary = im < -400
    # Step 2: Remove the blobs connected to the border of the image.
    cleared = clear_border(binary)
    # Step 3: Label the connected components.
    label_image = label(cleared)
    # Step 4: Keep the labels with the 2 largest areas. Components whose area ties
    # with the second largest are kept as well.
    regions = regionprops(label_image)
    areas = sorted(r.area for r in regions)
    if len(areas) > 2:
        second_largest = areas[-2]
        for region in regions:
            if region.area < second_largest:
                # Zero the whole component at once via its (row, col) coordinates.
                coords = region.coords
                label_image[coords[:, 0], coords[:, 1]] = 0
    binary = label_image > 0
    # Step 5: Erosion with a disk of radius 2, intended to separate lung nodules attached to blood vessels.
    selem = disk(2)
    binary = binary_erosion(binary, selem)
    # Step 6: Closing with a disk of radius 10, intended to keep nodules attached to the lung wall.
    selem = disk(10)
    binary = binary_closing(binary, selem)
    # Step 7: Fill small holes inside the lung mask (edge map, then hole filling).
    edges = roberts(binary)
    binary = ndi.binary_fill_holes(edges)
    # Step 8: Superimpose the mask on the input image; everything outside becomes -2000.
    get_high_vals = binary == 0
    im[get_high_vals] = -2000
    if plot:
        plt.figure(figsize=(10, 10))
        plt.subplot(1, 2, 1)
        plt.imshow(binary, cmap='gray')
        plt.title('Lung Mask')
        plt.subplot(1, 2, 2)
        plt.imshow(im, cmap='gray')
        plt.title('Masked Image')
        plt.show()
    return im, binary