45. self.signal = scale(signal).astype('float32')
46. if self.signal is None:
47. raise Exception("No MLII LEAD")
48.
49. self.peaks = find_peaks(self.signal, distance=180)[0]
50. mask_left = (self.peaks - self.mode // 2) > 0
51. mask_right = (self.peaks + self.mode // 2) < len(self.signal)
52. mask = mask_left & mask_right
53. self.peaks = self.peaks[mask]
54.
55. def __getitem__(self, index):
56. peak = self.peaks[index]
57. left, right = peak - self.mode // 2, peak + self.mode // 2
58.
59. img = self.signal[left:right]
60. img = img.reshape(1, img.shape[0])
61.
62. return {
63. "image": img,
64. "peak": peak
65. }
66.
67. def get_dataloader(self, num_workers=4, batch_size=16, shuffle=True):
68. data_loader = DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
69. return data_loader
70.
71. def __len__(self):
72. return len(self.peaks)
1. import json
2.
3. import cv2
4. from albumentations import Normalize, Compose
5. from albumentations.pytorch.transforms import ToTensorV2
6. from torch.utils.data import Dataset, DataLoader
7.
# Shared image preprocessing pipeline applied to every sample:
# Normalize() rescales pixel values using the library's default mean/std
# (presumably ImageNet statistics — TODO confirm against albumentations docs),
# then ToTensorV2() converts the numpy HWC image into a torch tensor.
augment = Compose([
    Normalize(),
    ToTensorV2()
])
12.
13.
class EcgDataset2D(Dataset):
    """Dataset of 2-D ECG images stored on disk.

    Reads an annotation JSON (a list of {"path": ..., "label": ...} records)
    and a mapping JSON (label string -> integer class id).
    """

    def __init__(self, ann_path, mapping_path):
        """Load annotations and the label-to-class mapping.

        Args:
            ann_path: path to the JSON list of sample records.
            mapping_path: path to the JSON label -> class-id dict.
        """
        super().__init__()
        # Context managers close the files promptly; the original
        # `json.load(open(...))` leaked both file handles.
        with open(ann_path) as f:
            self.data = json.load(f)
        with open(mapping_path) as f:
            self.mapper = json.load(f)

    def __getitem__(self, index):
        """Return the augmented image tensor and its integer class id."""
        entry = self.data[index]
        img = cv2.imread(entry['path'])
        if img is None:
            # cv2.imread returns None for missing/unreadable files instead of
            # raising; fail loudly here rather than crash inside `augment`.
            raise FileNotFoundError("Could not read image: %s" % entry['path'])
        img = augment(image=img)['image']

        return {
            "image": img,
            "class": self.mapper[entry['label']]
        }

    def get_dataloader(self, num_workers=4, batch_size=16, shuffle=True):
        """Wrap this dataset in a torch DataLoader with the given settings."""
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers)

    def __len__(self):
        """Number of annotated samples."""
        return len(self.data)
Приложение 7
1. import os
2. import os.path as osp
3. from datetime import datetime
4.
5. import numpy as np
6. import torch
7. from torch import optim, nn
8. from torch.utils.tensorboard import SummaryWriter
9. from tqdm import tqdm
10.
11. from utils.network_utils import load_checkpoint, save_checkpoint
12.
13.
14. class BaseTrainer:
    def __init__(self, config):
        """Set up the experiment from a config dict.

        Creates log/checkpoint directories, a TensorBoard writer, the model,
        optimizer, loss and dataloaders; optionally resumes epoch/iteration
        counters from a checkpoint given by config['model_path'].
        """
        self.config = config
        # Experiment name: explicit in config, otherwise a timestamp.
        self.exp_name = self.config.get('exp_name', None)
        if self.exp_name is None:
            self.exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

        self.log_dir = osp.join(self.config['exp_dir'], self.exp_name, 'logs')
        self.pth_dir = osp.join(self.config['exp_dir'], self.exp_name, 'checkpoints')
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.pth_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self.log_dir)

        # The model must exist before the optimizer: _init_optimizer reads
        # self.model.parameters().
        self.model = self._init_net()
        self.optimizer = self._init_optimizer()
        self.criterion = nn.CrossEntropyLoss().to(self.config['device'])

        self.train_loader, self.val_loader = self._init_dataloaders()

        # Resume training state (epoch and global iteration) when a
        # checkpoint path is supplied; otherwise start from scratch.
        pretrained_path = self.config.get('model_path', False)
        if pretrained_path:
            self.training_epoch, self.total_iter = load_checkpoint(pretrained_path, self.model,
                                                                   optimizer=self.optimizer)

        else:
            self.training_epoch = 0
            self.total_iter = 0

        self.epochs = self.config.get('epochs', int(1e5))
44.
45. def _init_net(self):
46. raise NotImplemented
47.
48. def _init_dataloaders(self):
49. raise NotImplemented
50.
51. def _init_optimizer(self):
52. optimizer = getattr(optim, self.config['optim'])(self.model.parameters(), **self.config['optim_params'])
53. return optimizer
54.
    def train_epoch(self):
        """Run one optimization pass over the training loader.

        For each batch: forward, loss, backward, optimizer step; logs the
        per-iteration loss to TensorBoard (indexed by self.total_iter) and
        prints every 10th iteration. At the end of the epoch, logs mean loss
        and classification accuracy (indexed by self.training_epoch).
        """
        self.model.train()
        total_loss = 0

        # Ground-truth and predicted class ids accumulated over the epoch.
        gt_class = np.empty(0)
        pd_class = np.empty(0)

        for i, batch in enumerate(self.train_loader):
            inputs = batch['image'].to(self.config['device'])
            targets = batch['class'].to(self.config['device'])

            predictions = self.model(inputs)
            loss = self.criterion(predictions, targets)

            # Top-1 predicted class per sample, flattened to a 1-D array.
            classes = predictions.topk(k=1)[1].view(-1).cpu().numpy()

            gt_class = np.concatenate((gt_class, batch['class'].numpy()))
            pd_class = np.concatenate((pd_class, classes))

            total_loss += loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if (i + 1) % 10 == 0:
                print("\tIter [%d/%d] Loss: %.4f" % (i + 1, len(self.train_loader), loss.item()))

            # Global iteration counter persists across epochs (and resumes
            # from checkpoints), so it is incremented here, not reset.
            self.writer.add_scalar("Train loss (iterations)", loss.item(), self.total_iter)
            self.total_iter += 1

        total_loss /= len(self.train_loader)
        class_accuracy = sum(pd_class == gt_class) / pd_class.shape[0]

        print('Train loss - {:4f}'.format(total_loss))
        print('Train CLASS accuracy - {:4f}'.format(class_accuracy))

        self.writer.add_scalar('Train loss (epochs)', total_loss, self.training_epoch)
        self.writer.add_scalar('Train CLASS accuracy', class_accuracy, self.training_epoch)
94.
95. def val(self):
96. self.model.eval()
97. total_loss = 0
98.
99. gt_class = np.empty(0)
100. pd_class = np.empty(0)
101.
102. with torch.no_grad():
103. for i, batch in tqdm(enumerate(self.val_loader)):
104. inputs = batch['image'].to(self.config['device'])