Дипломная работа: Диагностика заболеваний по ЭКГ с помощью сверточных нейронных сетей

Внимание! Если размещение файла нарушает Ваши авторские права, то обязательно сообщите нам

139.         self.layer5 = self._make_layer(block, 1024, layers[8], stride=2,

140.                                        dilate=replace_stride_with_dilation[2])

141.         self.avgpool = nn.AdaptiveAvgPool1d(1)

142.         self.fc = nn.Linear(1024 * block.expansion, num_classes)

143.  

144.         for m in self.modules():

145.             if isinstance(m, nn.Conv1d):

146.                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

147.             elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):

148.                 nn.init.constant_(m.weight, 1)

149.                 nn.init.constant_(m.bias, 0)

150.  

151.         # Zero-initialize the last BN in each residual branch,

152.         # so that the residual branch starts with zeros, and each residual block behaves like an identity.

153.         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677

154.         if zero_init_residual:

155.             for m in self.modules():

156.                 if isinstance(m, BasicBlockHeartNet):

157.                     nn.init.constant_(m.bn2.weight, 0)

158.  

159.     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):

160.         norm_layer = self._norm_layer

161.         downsample = None

162.         previous_dilation = self.dilation

163.         self.stride = stride

164.         if dilate:

165.             self.dilation *= stride

166.             stride = 1

167.         if stride != 1 or self.inplanes != planes * block.expansion:

168.             downsample = nn.Sequential(

169.                 conv_subsumpling(self.inplanes, planes * block.expansion)

170.             )

171.  

172.         layers = []

173.         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,

174.                             self.base_width, previous_dilation, norm_layer))

175.         self.inplanes = planes * block.expansion

176.         for _ in range(1, blocks):

177.             layers.append(block(self.inplanes, planes, groups=self.groups,

178.                                 base_width=self.base_width, dilation=self.dilation,

179.                                 norm_layer=norm_layer))

180.  

181.         return nn.Sequential(*layers)

182.  

183.     def forward(self, x):

184.         x = self.conv1(x)

185.  

186.         x = self.layer0(x)

187.         x = self.layer1(x)

188.         x = self.layer2(x)

189.         x = self.layer2_(x)

190.         x = self.layer3(x)

191.         x = self.layer3_(x)

192.         x = self.layer4(x)

193.         x = self.layer4_(x)

194.         x = self.layer5(x)

195.  

196.         x = self.avgpool(x)

197.         x = x.reshape(x.size(0), -1)

198.         x = self.fc(x)

199.  

200.         return x

201.  

202.  

203. class EcgResNet34(nn.Module):

204.  

205.     def __init__(self, layers=(1, 5, 5, 5), num_classes=1000, zero_init_residual=False,

206.                  groups=1, width_per_group=64, replace_stride_with_dilation=None,

207.                  norm_layer=None, block=BasicBlock):

208.  

209.         super(EcgResNet34, self).__init__()

210.         if norm_layer is None:

211.             norm_layer = nn.BatchNorm1d

212.         self._norm_layer = norm_layer

213.  

214.         self.inplanes = 32

215.         self.dilation = 1

216.         if replace_stride_with_dilation is None:

217.             # each element in the tuple indicates if we should replace

218.             # the 2x2 stride with a dilated convolution instead

219.             replace_stride_with_dilation = [False, False, False]

220.         if len(replace_stride_with_dilation) != 3:

221.             raise ValueError("replace_stride_with_dilation should be None "

222.                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

223.         self.groups = groups

224.         self.base_width = width_per_group

225.         self.conv1 = conv_block(1, self.inplanes, stride=1,)

226.         self.bn1 = norm_layer(self.inplanes)

227.         self.relu = nn.ReLU(inplace=True)

228.         self.layer1 = self._make_layer(block, 64, layers[0])

229.         self.layer2 = self._make_layer(block, 128, layers[1], stride=2,

230.                                        dilate=replace_stride_with_dilation[0])

231.         self.layer3 = self._make_layer(block, 256, layers[2], stride=2,

232.                                        dilate=replace_stride_with_dilation[1])

233.         self.layer4 = self._make_layer(block, 512, layers[3], stride=2,

234.                                        dilate=replace_stride_with_dilation[2])

235.         self.avgpool = nn.AdaptiveAvgPool1d(1)

236.         self.fc = nn.Linear(512 * block.expansion, num_classes)

237.  

238.         for m in self.modules():

239.             if isinstance(m, nn.Conv1d):

240.                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

241.             elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):

242.                 nn.init.constant_(m.weight, 1)

243.                 nn.init.constant_(m.bias, 0)

244.  

245.         # Zero-initialize the last BN in each residual branch,

246.         # so that the residual branch starts with zeros, and each residual block behaves like an identity.

247.         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677

248.         if zero_init_residual:

249.             for m in self.modules():

250.                 if isinstance(m, BasicBlock):

251.                     nn.init.constant_(m.bn2.weight, 0)

252.  

253.     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):

254.         norm_layer = self._norm_layer

255.         downsample = None

256.         previous_dilation = self.dilation

257.         if dilate:

258.             self.dilation *= stride

259.             stride = 1

260.         if stride != 1 or self.inplanes != planes * block.expansion:

261.             downsample = nn.Sequential(

262.                 conv_subsumpling(self.inplanes, planes * block.expansion, stride),

263.                 norm_layer(planes * block.expansion),

264.             )

265.  

266.         layers = []

267.         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,

268.                             self.base_width, previous_dilation, norm_layer))

269.         self.inplanes = planes * block.expansion

270.         for _ in range(1, blocks):

271.             layers.append(block(self.inplanes, planes, groups=self.groups,

272.                                 base_width=self.base_width, dilation=self.dilation,

273.                                 norm_layer=norm_layer))

274.  

275.         return nn.Sequential(*layers)

276.  

277.     def forward(self, x):

278.         x = self.conv1(x)

279.  

280.         x = self.layer1(x)

281.         x = self.layer2(x)

282.         x = self.layer3(x)

283.         x = self.layer4(x)

284.  

285.         x = self.avgpool(x)

286.         x = x.reshape(x.size(0), -1)

287.         x = self.fc(x)

288.  

289.         return x

290.  

291.  

292. class HeartNetIEEE(nn.Module):

293.     def __init__(self, num_classes=8):

294.         super().__init__()

295.  

296.         self.features = nn.Sequential(

297.             nn.Conv1d(1, 64, kernel_size=5),

298.             nn.ReLU(inplace=True),

299.             nn.Conv1d(64, 64, kernel_size=5),

300.             nn.ReLU(inplace=True),

301.             nn.MaxPool1d(2),

302.             nn.Conv1d(64, 128, kernel_size=3),

303.             nn.ReLU(inplace=True),

304.             nn.Conv1d(128, 128, kernel_size=3),

305.             nn.ReLU(inplace=True),

306.             nn.MaxPool1d(2)

307.         )

308.  

309.         self.classifier = nn.Sequential(

310.             nn.Linear(128 * 28, 256),

311.             nn.Linear(256, 128),

312.             nn.Linear(128, num_classes)

313.         )

314.  

315.     def forward(self, x):

316.         x = self.features(x)

317.         x = x.view(x.size(0), 128 * 28)

318.         x = self.classifier(x)

319.         return x

320.  

321.  

322.  

323. class Flatten(nn.Module):

324.     def forward(self, input):

325.         return input.view(input.size(0), -1)

326.  

327.  

328. class ZolotyhNet(nn.Module):

329.     def __init__(self, num_classes=8):

330.         super().__init__()

331.  

332.         self.features_up = nn.Sequential(

333.             nn.Conv1d(1, 8, kernel_size=3, padding=1),

334.             nn.BatchNorm1d(8),

335.             nn.ReLU(inplace=True),

336.             nn.MaxPool1d(2),

337.  

338.             nn.Conv1d(8, 16, kernel_size=3, padding=1),

339.             nn.BatchNorm1d(16),

340.             nn.ReLU(inplace=True),

341.             nn.MaxPool1d(2),

342.  

343.             nn.Conv1d(16, 32, kernel_size=3, padding=1),

344.             nn.BatchNorm1d(32),

345.             nn.ReLU(inplace=True),

346.             nn.MaxPool1d(2),

347.  

348.             nn.Conv1d(32, 32, kernel_size=3, padding=1),

349.             nn.BatchNorm1d(32),

350.             nn.ReLU(inplace=True),

351.             nn.MaxPool1d(2),

352.  

353.             nn.Conv1d(32, 1, kernel_size=3, padding=1),

354.             Flatten(),

355.         )

356.  

357.         self.features_down = nn.Sequential(

358.             Flatten(),

359.             nn.Linear(128,64),

360.             nn.BatchNorm1d(64),

361.             nn.ReLU(inplace=True),

362.  

363.             nn.Linear(64, 16),

364.             nn.BatchNorm1d(16),

365.             nn.ReLU(inplace=True),

366.  

367.             nn.Linear(16, 8)

368.         )

369.  

370.         self.classifier = nn.Linear(8, num_classes)

371.  

372.     def forward(self, x):

373.         out_up = self.features_up(x)

374.         out_down = self.features_down(x)

375.         out_middle = out_up + out_down

376.  

377.         out = self.classifier(out_middle)

378.  

379.         return out

Приложение 5

1. import json

2. import os

3. import os.path as osp

4. import numpy as np

5. from datetime import datetime

6.  

7. import torch

8. import wfdb

9. from tqdm import tqdm

10. import plotly.graph_objects as go

11.  

12. from utils.network_utils import load_checkpoint

13.  

14.  

15. class BasePipeline:

16.     def __init__(self, config):

17.         self.config = config

18.         self.exp_name = self.config.get('exp_name', None)

19.         if self.exp_name is None:

20.             self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

21.  

22.         self.res_dir = osp.join(self.config['exp_dir'], self.exp_name, 'results')

23.         os.makedirs(self.res_dir, exist_ok=True)

24.  

25.         self.model = self._init_net()

26.  

27.         self.pipeline_loader = self._init_dataloader()

28.  

29.         self.mapper = json.load(open(config['mapping_json']))

30.         self.mapper = {j: i for i, j in self.mapper.items()}

31.  

32.         pretrained_path = self.config.get('model_path', False)

33.         if pretrained_path:

34.             load_checkpoint(pretrained_path, self.model)

35.         else:

36.             raise Exception("model_path doesnt't exist in config. Please specify checkpoint path")

37.  

38.     def _init_net(self):

39.         raise NotImplemented

40.  

41.     def _init_dataloader(self):

42.         raise NotImplemented

43.  

44.     def run_pipeline(self):

45.         self.model.eval()

46.         pd_class = np.empty(0)

47.         pd_peaks = np.empty(0)

48.  

49.         with torch.no_grad():

50.             for i, batch in tqdm(enumerate(self.pipeline_loader)):

51.                 inputs = batch['image'].to(self.config['device'])

52.  

53.                 predictions = self.model(inputs)

54.  

55.                 classes = predictions.topk(k=1)[1].view(-1).cpu().numpy()

56.  

57.                 pd_class = np.concatenate((pd_class, classes))

58.                 pd_peaks = np.concatenate((pd_peaks, batch['peak']))

59.  

60.         pd_class = pd_class.astype(int)

61.         pd_peaks = pd_peaks.astype(int)

62.  

63.         annotations = []

64.         for label, peak in zip(pd_class, pd_peaks):

65.             if peak < len(self.pipeline_loader.dataset.signal) and self.mapper[label] != 'N':

66.                 annotations.append({

67.                     "x": peak,

68.                     "y": self.pipeline_loader.dataset.signal[peak],

69.                     "text": self.mapper[label],

70.                     "xref": "x",

71.                     "yref": "y",

72.                     "showarrow": True,

73.                     "arrowcolor": "black",

74.                     "arrowhead": 1,

75.                     "arrowsize": 2

76.                 })

77.  

78.         if osp.exists(self.config['ecg_data'] + '.atr'):

79.             ann = wfdb.rdann(self.config['ecg_data'], extension='atr')

80.             for label, peak in zip(ann.symbol, ann.sample):

81.                 if peak < len(self.pipeline_loader.dataset.signal) and label != 'N':

82.                     annotations.append({

83.                         "x": peak,

84.                         "y": self.pipeline_loader.dataset.signal[peak] - 0.1,

85.                         "text": label,

86.                         "xref": "x",

87.                         "yref": "y",

88.                         "showarrow": False,

89.                         "bordercolor": "#c7c7c7",

90.                         "borderwidth": 1,

91.                         "borderpad": 4,

92.                         "bgcolor": "#ffffff",

93.                         "opacity": 1

94.                     })

95.  

96.         fig = go.Figure(data=go.Scatter(x=list(range(len(self.pipeline_loader.dataset.signal))), y=self.pipeline_loader.dataset.signal))

97.         fig.update_layout(title='ECG',

98.                           xaxis_title='Time',

99.                           yaxis_title='ECG Output Value',

100.                           title_x=0.5,

101.                           annotations=annotations,

102.                           autosize=True)

103.  

104.         fig.write_html(osp.join(self.res_dir, osp.basename(self.config['ecg_data'] + '.html')))

Приложение 6

1. import json

2.  

3. import wfdb

4. from scipy.signal import find_peaks

5. from sklearn.preprocessing import scale

6. from torch.utils.data import Dataset, DataLoader

7. import numpy as np

8.  

9.  

10. class EcgDataset1D(Dataset):

11.     def __init__(self, ann_path, mapping_path):

12.         super().__init__()

13.         self.data = json.load(open(ann_path))

14.         self.mapper = json.load(open(mapping_path))

15.  

16.     def __getitem__(self, index):

17.         img = np.load(self.data[index]['path']).astype('float32')

18.         img = img.reshape(1, img.shape[0])

19.  

20.         return {

21.             "image": img,

22.             "class": self.mapper[self.data[index]['label']]

23.         }

24.  

25.     def get_dataloader(self, num_workers=4, batch_size=16, shuffle=True):

26.         data_loader = DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

27.         return data_loader

28.  

29.     def __len__(self):

30.         return len(self.data)

31.  

32.  

33. def callback_get_label(dataset, idx):

34.     return dataset[idx]["class"]

35.  

36.  

37. class EcgPipelineDataset1D(Dataset):

38.     def __init__(self, path, mode=128):

39.         super().__init__()

40.         record = wfdb.rdrecord(path)

41.         self.signal = None

42.         self.mode = mode

43.         for sig_name, signal in zip(record.sig_name, record.p_signal.T):

44.             if sig_name in ['MLII', 'II'] and np.all(np.isfinite(signal)):