[1]:
from src import *

from tqdm import tqdm

import torch
import torch.nn as nn

from torch import optim

from scipy.signal import find_peaks
from scipy import stats
[2]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

Model Implementation PULSEGAN

[4]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # encoding layers
        self.conv1 = nn.Conv1d(1, 16, 3, stride=1)
        self.conv2 = nn.Conv1d(16, 32, 3, stride=1)
        self.conv3 = nn.Conv1d(32, 64, 3, stride=1)
        self.conv4 = nn.Conv1d(64, 96, 3, stride=1)
        self.conv5 = nn.Conv1d(96, 128, 3, stride=1)
        self.conv6 = nn.Conv1d(128, 156, 3, stride=1)


        # decoding layers
        self.deconv6 = nn.ConvTranspose1d(156, 128, kernel_size=3, stride=1)
        self.deconv5 = nn.ConvTranspose1d(128, 96, kernel_size=3, stride=1)
        self.deconv4 = nn.ConvTranspose1d(96, 64, kernel_size=3, stride=1)
        self.deconv3 = nn.ConvTranspose1d(64, 32, kernel_size=3, stride=1)
        self.deconv2 = nn.ConvTranspose1d(32, 16, kernel_size=3, stride=1)
        self.deconv1 = nn.ConvTranspose1d(16, 1, kernel_size=3, stride=1)

        # activations
        self.activ1 = nn.PReLU() #not inplace, I want to copy
        self.activ2 = nn.Tanh()


    def forward(self, x):

        # encoder
        x = self.conv1(x)
        res1 = self.activ1(x)

        x = self.conv2(res1)
        res2 = self.activ1(x)

        x = self.conv3(res2)
        res3 = self.activ1(x)

        x = self.conv4(res3)
        res4 = self.activ1(x)

        x = self.conv5(res4)
        res5 = self.activ1(x)

        x = self.conv6(res5)
        x = self.activ1(x)



        # decoder
        x = self.deconv6(x)
        x = self.activ1(x)
        x += res5

        x = self.deconv5(x)
        x = self.activ1(x)
        x += res4

        x = self.deconv4(x)
        x = self.activ1(x)
        x += res3

        x = self.deconv3(x)
        x = self.activ1(x)
        x += res2

        x = self.deconv2(x)
        x = self.activ1(x)
        x += res1

        x = self.deconv1(x)
        x = self.activ2(x)


        return x




class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # encoding layers
        self.encode = nn.Sequential(
            nn.Conv1d(2, 8, 3, stride=1),
            nn.BatchNorm1d(8),
            nn.LeakyReLU(0.1),
            nn.Conv1d(8, 12, 3, stride=1),
            nn.BatchNorm1d(12),
            nn.LeakyReLU(0.1),
            nn.Conv1d(12, 16, 3, stride=1),
            nn.BatchNorm1d(16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(16, 20, 3, stride=1),
            nn.BatchNorm1d(20),
            nn.LeakyReLU(0.1),
            nn.Conv1d(20, 24, 3, stride=1),
            nn.BatchNorm1d(24),
            nn.LeakyReLU(0.1),

        )

        self.embed = nn.Sequential(
            nn.Conv1d(24, 6, 3, stride=1),
            nn.LeakyReLU(0.1),
        )

        self.head = nn.Sequential(
            nn.Linear(696,128),
            nn.ReLU(),
            nn.Linear(128,24),
            nn.ReLU(),
            nn.Linear(24,1),
            nn.Sigmoid()
        )

    def forward(self, X, Xc):
        # concatenate both ground truth or predicted + original signal (used as conditional)
        b, c, l = X.shape

        Xi = torch.cat([X,Xc],axis=1)

        Xi = self.encode(Xi)
        Xi = self.embed(Xi)

        Xi = Xi.reshape(b,-1)
        proba = self.head(Xi)

        return proba


class PulseGan(nn.Module):

    def __init__(self):
        super(PulseGan,self).__init__()

        self.G = Generator()
        self.D = Discriminator()


    def forward(self, X, Xc):
        Xrppg = self.G(X)

        # Input of original rppg and Ground Truth
        proba_real = self.D(X,Xc)

        # Input of original rppg and generator output
        proba_fake = self.D(X,Xrppg)

        return X,Xc,Xrppg,proba_fake, proba_real
[5]:
class Gloss(nn.Module):
    def __init__(self):
        super(Gloss,self).__init__()
        return

    def forward(self, rppg_gt, rppg_pred, proba_pred, lbd=.1, beta=.2):
        # accumulating loss for all sample
        loss = 0
        b, c, l = rppg_gt.shape

        for i in range(b):
            d = proba_pred[i]
            Xc = rppg_gt[i]
            Xg = rppg_pred[i]



            # loss += 1/2 * (d - 1)**2

            # loss += lbd * torch.sum(torch.abs(Xc-Xg))

            #loss += beta * torch.sum(self.fft_loss(Xc,Xg))

            loss += beta * self.peak_loss(Xc,Xg)

        return loss


    def fft_loss(self,Xc,Xg):
        return torch.abs(torch.fft.fft(Xc) - torch.fft.fft(Xg))

    def peak_loss(self,Xc,Xg):
        idx_device = Xc.device.index
        d = torch.device("cuda:{}".format(idx_device))

        rppgC = Xc.cpu().flatten()
        rppgG = Xg.cpu().flatten().detach()
        peaksC, _ = find_peaks(rppgC, height=0.5) #arbitrary height simply because normalized rppg
        peaksG, _ = find_peaks(rppgG, height=0.5)

        dist = self.min_dist(peaksC, peaksG)
        return dist

    def min_dist(self,I1,I2):
        # NOTE: NOT COMMUTATIVE
        # I1 should be ground truth
        ret = 0
        if len(I2) == 0:
            I2 = np.zeros_like(I1)
        for val in I1:
            diff = np.min(np.abs(I2-val))
            ret += (diff) # penalize way more large differences
        return ret


class Dloss(nn.Module):
    def __init__(self):
        super(Dloss,self).__init__()
        return

    def forward(self, proba_fake, proba_real):
        # accumulating loss for all sample
        b = proba_fake.shape[0]
        loss = 0

        for i in range(b):
            df = proba_fake[i]
            dr = proba_real[i]

            loss += 1/2 * df ** 2 + 1/2 * (dr-1)**2

        return loss

DataLoader

[6]:
# # TO AUGMENT SPECIFIC MOTION DATA
# eg = pd.read_csv('../experiments/csvs/phys_pred2.csv')

# intervals = {
#     'exp41':[700, 829],
#     'exp46':[1620, 1800],
#     'exp49':[1700, 1900]
# }


# ret = {
#     'id': [],
#     'pred': [],
#     'gt': []
# }

# for exp in intervals.keys():
#     a,b = intervals[exp]
#     roi = eg[eg.id==exp].reset_index(drop=True)
#     roi = roi.iloc[a:b]

#     preds = list(roi['pred'])
#     gts = list(roi['gt'])


#     for i in range(100):
#         ret['id'] += [exp+'_'+str(i)] * len(roi)
#         ret['pred'] += preds
#         ret['gt'] += gts

# # new = pd.concat([eg,pd.DataFrame.from_dict(ret)])
# new = pd.DataFrame.from_dict(ret)
# new.to_csv('motionpulse_gan.csv',index=False)
[7]:
class PulseDataset(Dataset):
    """Dataset of rPPGs"""

    def __init__(self, csv_file, D = 128):
        """
        Args:
            csv_file (string): Path to the csv file with rpp predictions and ground truths
        """
        self.info = pd.read_csv(csv_file)
        self.D = D

        exps = self.info.id.unique()
        self.num_samples = sum([len(self.info.loc[self.info.id == exp])//self.D for exp in exps])

        # normalize values
        self.info['pred'] = normalize(self.info['pred'])
        self.info['gt'] = normalize(self.info['gt'])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        et = np.array(self.info['pred'].iloc[idx*self.D:(idx+1)*self.D])
        gt = np.array(self.info['gt'].iloc[idx*self.D:(idx+1)*self.D])
        et = et.reshape(1,-1)
        gt = gt.reshape(1,-1)

        # cast to tensor
        et = torch.from_numpy(et).type(torch.FloatTensor)
        gt = torch.from_numpy(gt).type(torch.FloatTensor)

        return et,gt
[8]:
train_dataset = PulseDataset('../experiments/csvs/phys_pred2.csv')
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

train, test = torch.utils.data.random_split(trainloader.dataset, [len(trainloader.dataset)//2,
                                                                 len(trainloader.dataset)//2])

trainloader = torch.utils.data.DataLoader(train, batch_size=3, shuffle=False)
testloader = torch.utils.data.DataLoader(test, batch_size=3, shuffle=False)

dataloaders = {
    'train': trainloader,
    'test': testloader
}

Model Training

[24]:
model = PulseGan()
model = model.to(device)

# PATH = './checkpoints/PulseGan2_49.pt'
# model.load_state_dict(torch.load(PATH))

criterion_G = Gloss()
criterion_D = Dloss()

optimizer = optim.Adam(model.parameters(), lr=0.0001)


val_loss_history = []
train_loss_history = []

num_epochs = 50

for epoch in tqdm(range(num_epochs)):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training and validation phase
    phases = ['train', 'test']
    for phase in phases:
        running_loss = 0.0
        if phase == 'train':
            model.train()  # Set model to training mode -> activate droput layers and batch norm
        else:
            model.eval()  # Set model to evaluate mode

        # Iterate over data.
        for inputs, targets in dataloaders[phase]:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            # track history if only in train
            with torch.set_grad_enabled(phase == 'train'):
                X,Xc,Xrppg, proba_fake, proba_real = model(inputs,targets)

                loss = criterion_G(Xc, Xrppg, proba_fake) + criterion_D(proba_fake, proba_real)
                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    print(torch.sum(model.G.conv1.weight.grad))
                    optimizer.step()

            # statistics
            running_loss += loss.item()

        epoch_loss = running_loss / len(dataloaders[phase])
        print('{} Loss: {:.4f} '.format(phase, epoch_loss))

        if phase == 'test':
            val_loss_history.append(epoch_loss)
        else:
            train_loss_history.append(epoch_loss)

    torch.save(model.state_dict(), f'checkpoints/PulseGan3_{epoch}.pt')

  0%|                                                                                          | 0/50 [00:00<?, ?it/s]
Epoch 0/49
----------
tensor(0.0234, device='cuda:1')
tensor(-0.0743, device='cuda:1')
tensor(0.1206, device='cuda:1')
tensor(-0.0021, device='cuda:1')
tensor(-0.1285, device='cuda:1')
tensor(0.0246, device='cuda:1')
tensor(0.0382, device='cuda:1')
tensor(-0.0018, device='cuda:1')
tensor(-0.0075, device='cuda:1')
tensor(-0.0245, device='cuda:1')
tensor(0.0204, device='cuda:1')
tensor(-0.1357, device='cuda:1')
tensor(0.1360, device='cuda:1')
tensor(0.0730, device='cuda:1')
tensor(-0.0198, device='cuda:1')
tensor(-0.1040, device='cuda:1')
tensor(-0.0134, device='cuda:1')
tensor(0.0175, device='cuda:1')
tensor(0.0108, device='cuda:1')
tensor(-0.0965, device='cuda:1')
tensor(-0.1338, device='cuda:1')
tensor(0.0200, device='cuda:1')
tensor(0.0848, device='cuda:1')
tensor(-0.0622, device='cuda:1')
tensor(0.0418, device='cuda:1')
tensor(-0.1093, device='cuda:1')
tensor(-0.0962, device='cuda:1')
tensor(-0.0813, device='cuda:1')
tensor(0.0491, device='cuda:1')
tensor(0.0067, device='cuda:1')
tensor(0.0328, device='cuda:1')
tensor(-0.0428, device='cuda:1')
tensor(0.0088, device='cuda:1')
tensor(0.0259, device='cuda:1')
tensor(0.0136, device='cuda:1')
tensor(0.0074, device='cuda:1')
tensor(0.0354, device='cuda:1')
tensor(0.0062, device='cuda:1')
tensor(0.0257, device='cuda:1')
tensor(-0.1017, device='cuda:1')
tensor(0.1740, device='cuda:1')
tensor(0.1071, device='cuda:1')
tensor(0.0378, device='cuda:1')
tensor(-0.2264, device='cuda:1')
tensor(-0.0130, device='cuda:1')
tensor(-0.0042, device='cuda:1')
tensor(-0.0311, device='cuda:1')
tensor(0.0183, device='cuda:1')
tensor(0.1103, device='cuda:1')
tensor(0.0898, device='cuda:1')
tensor(0.0271, device='cuda:1')
tensor(0.0695, device='cuda:1')
tensor(-0.1701, device='cuda:1')
tensor(-0.0974, device='cuda:1')
tensor(0.0893, device='cuda:1')
tensor(-0.0764, device='cuda:1')
tensor(-0.0080, device='cuda:1')
tensor(-0.0192, device='cuda:1')
tensor(0.1005, device='cuda:1')
tensor(0.0511, device='cuda:1')
tensor(-0.1247, device='cuda:1')
tensor(-2.6003e-06, device='cuda:1')
tensor(0.0542, device='cuda:1')
tensor(0.0854, device='cuda:1')
tensor(0.1142, device='cuda:1')
tensor(-0.1003, device='cuda:1')
tensor(0.0157, device='cuda:1')
tensor(-0.0284, device='cuda:1')
tensor(0.0932, device='cuda:1')
tensor(0.0014, device='cuda:1')
tensor(0.0446, device='cuda:1')
tensor(-0.0111, device='cuda:1')
tensor(0.0552, device='cuda:1')
tensor(-0.1451, device='cuda:1')
tensor(0.0141, device='cuda:1')
tensor(-0.0522, device='cuda:1')
tensor(-0.0438, device='cuda:1')
tensor(-0.2118, device='cuda:1')
tensor(0.0863, device='cuda:1')
tensor(-0.0207, device='cuda:1')
tensor(-0.0482, device='cuda:1')
tensor(-0.0817, device='cuda:1')
tensor(-0.0443, device='cuda:1')
tensor(-0.1473, device='cuda:1')
tensor(-0.0237, device='cuda:1')
tensor(0.0896, device='cuda:1')
tensor(0.1545, device='cuda:1')
tensor(0.1527, device='cuda:1')
tensor(0.0985, device='cuda:1')
tensor(0.2924, device='cuda:1')
train Loss: 4.3812

  2%|█▋                                                                                | 1/50 [00:02<01:54,  2.34s/it]
test Loss: 4.3989
Epoch 1/49
----------
tensor(-0.0466, device='cuda:1')
tensor(0.0706, device='cuda:1')
tensor(-0.0634, device='cuda:1')
tensor(0.0787, device='cuda:1')
tensor(0.0988, device='cuda:1')
tensor(-0.1149, device='cuda:1')
tensor(-0.1962, device='cuda:1')
tensor(-0.0200, device='cuda:1')
tensor(-0.0168, device='cuda:1')
tensor(-0.1828, device='cuda:1')
tensor(-0.0935, device='cuda:1')
tensor(0.0397, device='cuda:1')
tensor(-0.1368, device='cuda:1')
tensor(0.2060, device='cuda:1')
tensor(-0.0857, device='cuda:1')
tensor(-0.2954, device='cuda:1')
tensor(-0.0978, device='cuda:1')
tensor(0.1871, device='cuda:1')
tensor(-0.3209, device='cuda:1')
tensor(0.3294, device='cuda:1')
tensor(0.1604, device='cuda:1')
tensor(-0.0780, device='cuda:1')
tensor(0.1306, device='cuda:1')
tensor(0.0174, device='cuda:1')
tensor(-0.0617, device='cuda:1')
tensor(0.2473, device='cuda:1')
tensor(0.0747, device='cuda:1')
tensor(0.1072, device='cuda:1')
tensor(0.2234, device='cuda:1')
tensor(-0.0426, device='cuda:1')
tensor(0.0030, device='cuda:1')
tensor(-0.1952, device='cuda:1')
tensor(0.0560, device='cuda:1')
tensor(0.2956, device='cuda:1')
tensor(-0.0517, device='cuda:1')
tensor(0.1840, device='cuda:1')
tensor(0.1116, device='cuda:1')
tensor(-0.0444, device='cuda:1')
tensor(-0.1981, device='cuda:1')
tensor(0.1247, device='cuda:1')
tensor(0.2024, device='cuda:1')
tensor(-0.0146, device='cuda:1')
tensor(0.0853, device='cuda:1')
tensor(-0.5585, device='cuda:1')
tensor(0.4518, device='cuda:1')
tensor(0.3298, device='cuda:1')
tensor(0.1341, device='cuda:1')
tensor(-0.0039, device='cuda:1')
tensor(0.2079, device='cuda:1')
tensor(-0.0561, device='cuda:1')
tensor(0.0712, device='cuda:1')
tensor(0.0156, device='cuda:1')
tensor(-1.7867, device='cuda:1')
tensor(-0.1428, device='cuda:1')
tensor(0.3574, device='cuda:1')
tensor(-0.6353, device='cuda:1')
tensor(0.2975, device='cuda:1')
tensor(0.4641, device='cuda:1')
tensor(0.4429, device='cuda:1')
tensor(0.1414, device='cuda:1')
tensor(0.5793, device='cuda:1')
tensor(0.1294, device='cuda:1')
tensor(-0.0399, device='cuda:1')
tensor(0.4883, device='cuda:1')
tensor(0.1969, device='cuda:1')
tensor(0.5092, device='cuda:1')
tensor(-0.0338, device='cuda:1')
tensor(-0.0100, device='cuda:1')
tensor(0.0408, device='cuda:1')
tensor(0.0938, device='cuda:1')
tensor(0.3074, device='cuda:1')
tensor(-0.3922, device='cuda:1')
tensor(0.1945, device='cuda:1')
tensor(-0.1447, device='cuda:1')
tensor(-0.7290, device='cuda:1')
tensor(0.0621, device='cuda:1')
tensor(-0.2128, device='cuda:1')
tensor(-0.3295, device='cuda:1')
tensor(-0.5787, device='cuda:1')
tensor(0.5662, device='cuda:1')
tensor(0.2572, device='cuda:1')
tensor(0.0487, device='cuda:1')
tensor(0.6624, device='cuda:1')
tensor(0.4487, device='cuda:1')
tensor(0.5060, device='cuda:1')
tensor(0.6728, device='cuda:1')
tensor(0.2226, device='cuda:1')
tensor(0.7603, device='cuda:1')
tensor(-0.4337, device='cuda:1')
tensor(1.0581, device='cuda:1')
train Loss: 4.3063

  4%|███▎                                                                              | 2/50 [00:04<01:50,  2.31s/it]
test Loss: 4.8428
Epoch 2/49
----------
tensor(-0.1518, device='cuda:1')
tensor(-0.4037, device='cuda:1')
tensor(0.0665, device='cuda:1')
tensor(0.1000, device='cuda:1')
tensor(-0.3869, device='cuda:1')
tensor(0.6785, device='cuda:1')
tensor(-1.8758, device='cuda:1')
tensor(1.4685, device='cuda:1')
tensor(0.2091, device='cuda:1')
tensor(-0.1320, device='cuda:1')
tensor(-0.1493, device='cuda:1')
tensor(0.1929, device='cuda:1')
tensor(0.1351, device='cuda:1')
tensor(1.1573, device='cuda:1')
tensor(0.2140, device='cuda:1')
tensor(0.2945, device='cuda:1')
tensor(0.0977, device='cuda:1')
tensor(0.2490, device='cuda:1')
tensor(0.9372, device='cuda:1')
tensor(0.4538, device='cuda:1')
tensor(-0.7219, device='cuda:1')
tensor(-0.2956, device='cuda:1')
tensor(-0.0590, device='cuda:1')
tensor(-1.4257, device='cuda:1')
tensor(-0.5703, device='cuda:1')
tensor(-0.0218, device='cuda:1')
tensor(-0.1934, device='cuda:1')
tensor(0.5188, device='cuda:1')
tensor(0.0345, device='cuda:1')
tensor(-0.3343, device='cuda:1')
tensor(1.2776, device='cuda:1')
tensor(1.3704, device='cuda:1')
tensor(0.9921, device='cuda:1')
tensor(0.3158, device='cuda:1')
tensor(0.1970, device='cuda:1')
tensor(0.1943, device='cuda:1')
tensor(1.0164, device='cuda:1')
tensor(-0.3316, device='cuda:1')
tensor(0.0655, device='cuda:1')
tensor(0.2339, device='cuda:1')
tensor(0.2097, device='cuda:1')
tensor(-0.6831, device='cuda:1')
tensor(-1.9790, device='cuda:1')
tensor(-0.1344, device='cuda:1')
tensor(0.3001, device='cuda:1')
tensor(0.0550, device='cuda:1')
tensor(0.1961, device='cuda:1')
tensor(0.0219, device='cuda:1')
tensor(-0.0503, device='cuda:1')
tensor(-0.5584, device='cuda:1')
tensor(0.0476, device='cuda:1')
tensor(0.0385, device='cuda:1')
tensor(0.3746, device='cuda:1')
tensor(0.1593, device='cuda:1')
tensor(-0.0227, device='cuda:1')
tensor(-0.8946, device='cuda:1')
tensor(0.2608, device='cuda:1')
tensor(0.2546, device='cuda:1')
tensor(0.2744, device='cuda:1')
tensor(0.1534, device='cuda:1')
tensor(-0.0073, device='cuda:1')
tensor(0.2475, device='cuda:1')
tensor(0.1863, device='cuda:1')
tensor(0.3482, device='cuda:1')
tensor(0.0027, device='cuda:1')
tensor(0.0592, device='cuda:1')
tensor(0.0772, device='cuda:1')
tensor(0.1337, device='cuda:1')
tensor(0.0555, device='cuda:1')
tensor(0.0469, device='cuda:1')
tensor(0.0385, device='cuda:1')
tensor(0.6353, device='cuda:1')
tensor(-0.0043, device='cuda:1')
tensor(0.1416, device='cuda:1')
tensor(-0.0288, device='cuda:1')
tensor(-0.1519, device='cuda:1')
tensor(-0.4411, device='cuda:1')
tensor(-1.0242, device='cuda:1')
tensor(-0.0882, device='cuda:1')
tensor(-0.0065, device='cuda:1')
tensor(-0.0340, device='cuda:1')
tensor(-2.1300, device='cuda:1')
tensor(0.0155, device='cuda:1')
tensor(0.0585, device='cuda:1')
tensor(0.3391, device='cuda:1')
tensor(0.1415, device='cuda:1')
tensor(0.2140, device='cuda:1')
tensor(0.3713, device='cuda:1')
tensor(0.1584, device='cuda:1')
tensor(0.1127, device='cuda:1')
train Loss: 4.6357

  6%|████▉                                                                             | 3/50 [00:06<01:48,  2.30s/it]
test Loss: 8.0453
Epoch 3/49
----------
tensor(0.1251, device='cuda:1')
tensor(0.5965, device='cuda:1')
tensor(0.8128, device='cuda:1')
tensor(0.0873, device='cuda:1')
tensor(0.0457, device='cuda:1')
tensor(0.0806, device='cuda:1')
tensor(-0.0122, device='cuda:1')
tensor(0.0488, device='cuda:1')
tensor(-1.0687, device='cuda:1')
tensor(0.0129, device='cuda:1')
tensor(0.0638, device='cuda:1')
tensor(0.0726, device='cuda:1')
tensor(0.0339, device='cuda:1')
tensor(0.0384, device='cuda:1')
tensor(0.0302, device='cuda:1')
tensor(0.0669, device='cuda:1')
tensor(-0.0645, device='cuda:1')
tensor(-0.0290, device='cuda:1')
tensor(0.2668, device='cuda:1')
tensor(0.0244, device='cuda:1')
tensor(0.0068, device='cuda:1')
tensor(0.0501, device='cuda:1')
tensor(0.0091, device='cuda:1')
tensor(0.0353, device='cuda:1')
tensor(0.1462, device='cuda:1')
tensor(0.0282, device='cuda:1')
tensor(0.0473, device='cuda:1')
tensor(0.0179, device='cuda:1')
tensor(-0.0067, device='cuda:1')
tensor(0.2164, device='cuda:1')
tensor(0.0529, device='cuda:1')
tensor(-0.0611, device='cuda:1')
tensor(0.0327, device='cuda:1')
tensor(-0.0041, device='cuda:1')
tensor(0.0084, device='cuda:1')
tensor(0.0319, device='cuda:1')
tensor(0.0379, device='cuda:1')
tensor(-0.0167, device='cuda:1')
tensor(-0.0464, device='cuda:1')
tensor(0.0054, device='cuda:1')
tensor(-0.0016, device='cuda:1')
tensor(-0.2045, device='cuda:1')
tensor(-0.1570, device='cuda:1')
tensor(0.0061, device='cuda:1')
tensor(0.0483, device='cuda:1')
tensor(0.0128, device='cuda:1')
tensor(-0.1050, device='cuda:1')
tensor(0.0232, device='cuda:1')
tensor(0.0112, device='cuda:1')
tensor(-0.0972, device='cuda:1')
tensor(0.0113, device='cuda:1')
tensor(0.0188, device='cuda:1')
tensor(0.0066, device='cuda:1')
tensor(0.0289, device='cuda:1')
tensor(0.0031, device='cuda:1')
tensor(0.0127, device='cuda:1')
tensor(0.0168, device='cuda:1')
tensor(0.0380, device='cuda:1')
tensor(0.0312, device='cuda:1')
tensor(0.0025, device='cuda:1')
tensor(0.0040, device='cuda:1')
tensor(0.0288, device='cuda:1')
tensor(-0.0181, device='cuda:1')
tensor(0.0284, device='cuda:1')
tensor(0.0001, device='cuda:1')
tensor(0.0080, device='cuda:1')
tensor(0.0131, device='cuda:1')
tensor(0.0422, device='cuda:1')
tensor(0.0065, device='cuda:1')
tensor(0.0054, device='cuda:1')
tensor(0.0063, device='cuda:1')
tensor(0.0334, device='cuda:1')
tensor(-0.0216, device='cuda:1')
tensor(0.0172, device='cuda:1')
tensor(0.0142, device='cuda:1')
tensor(0.0346, device='cuda:1')
tensor(-0.1458, device='cuda:1')
tensor(0.0013, device='cuda:1')
tensor(-0.0066, device='cuda:1')
tensor(-0.0057, device='cuda:1')
tensor(-0.0262, device='cuda:1')
tensor(-0.0565, device='cuda:1')
tensor(0.0048, device='cuda:1')
tensor(0.0166, device='cuda:1')
tensor(0.0130, device='cuda:1')
tensor(0.0086, device='cuda:1')
tensor(0.0101, device='cuda:1')
tensor(-0.0022, device='cuda:1')
tensor(0.0166, device='cuda:1')
tensor(0.0084, device='cuda:1')
train Loss: 29.4669

  8%|██████▌                                                                           | 4/50 [00:09<01:45,  2.30s/it]
test Loss: 64.8366
Epoch 4/49
----------
tensor(0.0088, device='cuda:1')
tensor(0.0334, device='cuda:1')
tensor(-0.0190, device='cuda:1')
tensor(-0.0004, device='cuda:1')
tensor(0.0085, device='cuda:1')
tensor(0.0045, device='cuda:1')
tensor(-0.1080, device='cuda:1')
tensor(0.0116, device='cuda:1')
tensor(0.1879, device='cuda:1')
tensor(0.0074, device='cuda:1')
tensor(-0.0013, device='cuda:1')
tensor(0.0256, device='cuda:1')
tensor(0.0185, device='cuda:1')
tensor(-0.0046, device='cuda:1')
tensor(-0.0110, device='cuda:1')
tensor(0.0334, device='cuda:1')
tensor(-0.0202, device='cuda:1')
tensor(-0.0243, device='cuda:1')
tensor(0.0750, device='cuda:1')
tensor(0.0001, device='cuda:1')
tensor(-0.0007, device='cuda:1')
tensor(0.0196, device='cuda:1')
tensor(-0.0010, device='cuda:1')
tensor(-0.0031, device='cuda:1')
tensor(0.0278, device='cuda:1')
tensor(0.0049, device='cuda:1')
tensor(0.0029, device='cuda:1')
tensor(0.0048, device='cuda:1')
tensor(-0.0077, device='cuda:1')
tensor(0.0469, device='cuda:1')
tensor(0.0099, device='cuda:1')
tensor(-0.0336, device='cuda:1')
tensor(0.0044, device='cuda:1')
tensor(-0.0039, device='cuda:1')
tensor(0.0060, device='cuda:1')
tensor(0.0127, device='cuda:1')
tensor(0.0154, device='cuda:1')
tensor(-0.0037, device='cuda:1')
tensor(-0.0073, device='cuda:1')
tensor(0.0021, device='cuda:1')
tensor(-0.0069, device='cuda:1')
tensor(-0.0382, device='cuda:1')
tensor(-0.0279, device='cuda:1')
tensor(0.0075, device='cuda:1')
tensor(0.0193, device='cuda:1')
tensor(0.0013, device='cuda:1')
tensor(-0.1041, device='cuda:1')
tensor(0.0050, device='cuda:1')
tensor(0.0010, device='cuda:1')
tensor(-0.0567, device='cuda:1')
tensor(0.0106, device='cuda:1')
tensor(0.0071, device='cuda:1')
tensor(0.0066, device='cuda:1')
tensor(0.0083, device='cuda:1')
tensor(-0.0012, device='cuda:1')
tensor(0.0089, device='cuda:1')
tensor(0.0012, device='cuda:1')
tensor(0.0092, device='cuda:1')
tensor(0.0130, device='cuda:1')
tensor(-0.0039, device='cuda:1')
tensor(-0.0065, device='cuda:1')
tensor(0.0064, device='cuda:1')
tensor(-0.0147, device='cuda:1')
tensor(0.0046, device='cuda:1')
tensor(-0.0086, device='cuda:1')
tensor(0.0027, device='cuda:1')
tensor(0.0128, device='cuda:1')
tensor(0.0142, device='cuda:1')
tensor(0.0030, device='cuda:1')
tensor(0.0031, device='cuda:1')
tensor(0.0020, device='cuda:1')
tensor(0.0091, device='cuda:1')
tensor(-0.0189, device='cuda:1')
tensor(0.0282, device='cuda:1')
tensor(0.0078, device='cuda:1')
tensor(0.0126, device='cuda:1')
tensor(-0.0392, device='cuda:1')
tensor(-0.0030, device='cuda:1')
tensor(0.0028, device='cuda:1')
tensor(-0.0062, device='cuda:1')
tensor(-0.0051, device='cuda:1')
tensor(-0.0076, device='cuda:1')
tensor(0.0096, device='cuda:1')
tensor(0.0085, device='cuda:1')
tensor(0.0063, device='cuda:1')
tensor(0.0042, device='cuda:1')
tensor(0.0079, device='cuda:1')
tensor(-0.0007, device='cuda:1')
tensor(0.0088, device='cuda:1')
tensor(0.0034, device='cuda:1')
train Loss: 60.2398

 10%|████████▏                                                                         | 5/50 [00:11<01:44,  2.32s/it]
test Loss: 80.2666
Epoch 5/49
----------
tensor(0.0062, device='cuda:1')
tensor(0.0161, device='cuda:1')
tensor(-0.0009, device='cuda:1')
tensor(0.0035, device='cuda:1')
tensor(0.0055, device='cuda:1')
tensor(0.0051, device='cuda:1')
tensor(-0.0291, device='cuda:1')
tensor(0.0050, device='cuda:1')
tensor(0.1266, device='cuda:1')
tensor(0.0028, device='cuda:1')
tensor(0.0002, device='cuda:1')
tensor(0.0165, device='cuda:1')
tensor(0.0072, device='cuda:1')
tensor(-0.0107, device='cuda:1')
tensor(0.0048, device='cuda:1')
tensor(0.0265, device='cuda:1')
tensor(-0.0109, device='cuda:1')
tensor(-0.0188, device='cuda:1')
tensor(0.0227, device='cuda:1')
tensor(-0.0016, device='cuda:1')
tensor(-0.0035, device='cuda:1')
tensor(0.0055, device='cuda:1')
tensor(-0.0009, device='cuda:1')
tensor(-0.0011, device='cuda:1')
tensor(0.0041, device='cuda:1')
tensor(0.0006, device='cuda:1')
tensor(-0.0042, device='cuda:1')
tensor(0.0020, device='cuda:1')
tensor(-0.0097, device='cuda:1')
tensor(0.0206, device='cuda:1')
tensor(0.0042, device='cuda:1')
tensor(-0.0148, device='cuda:1')
tensor(0.0009, device='cuda:1')
tensor(-0.0029, device='cuda:1')
tensor(0.0024, device='cuda:1')
tensor(0.0068, device='cuda:1')
tensor(0.0066, device='cuda:1')
tensor(-0.0023, device='cuda:1')
tensor(-0.0019, device='cuda:1')
tensor(0.0012, device='cuda:1')
tensor(-0.0023, device='cuda:1')
tensor(-0.0195, device='cuda:1')
tensor(0.0021, device='cuda:1')
tensor(0.0033, device='cuda:1')
tensor(0.0111, device='cuda:1')
tensor(0.0023, device='cuda:1')
tensor(-0.0864, device='cuda:1')
tensor(0.0006, device='cuda:1')
tensor(-0.0004, device='cuda:1')
tensor(-0.0317, device='cuda:1')
tensor(0.0079, device='cuda:1')
tensor(0.0041, device='cuda:1')
tensor(0.0111, device='cuda:1')
tensor(0.0029, device='cuda:1')
tensor(0.0011, device='cuda:1')
tensor(0.0007, device='cuda:1')
tensor(0.0028, device='cuda:1')
tensor(0.0044, device='cuda:1')
tensor(0.0072, device='cuda:1')
tensor(-0.0030, device='cuda:1')
tensor(-0.0096, device='cuda:1')
tensor(0.0034, device='cuda:1')
tensor(-0.0073, device='cuda:1')
tensor(0.0034, device='cuda:1')
tensor(-0.0067, device='cuda:1')
tensor(0.0016, device='cuda:1')
tensor(0.0060, device='cuda:1')
tensor(0.0069, device='cuda:1')
tensor(0.0019, device='cuda:1')
tensor(0.0022, device='cuda:1')
tensor(0.0009, device='cuda:1')
tensor(0.0036, device='cuda:1')
tensor(-0.0110, device='cuda:1')
tensor(0.0179, device='cuda:1')
tensor(0.0045, device='cuda:1')
tensor(0.0057, device='cuda:1')
tensor(-0.0249, device='cuda:1')
tensor(-0.0072, device='cuda:1')
tensor(-0.0002, device='cuda:1')
tensor(-0.0040, device='cuda:1')
tensor(-0.0022, device='cuda:1')
tensor(-0.0074, device='cuda:1')
tensor(0.0053, device='cuda:1')
tensor(0.0034, device='cuda:1')
tensor(0.0034, device='cuda:1')
tensor(0.0026, device='cuda:1')
tensor(0.0033, device='cuda:1')
tensor(0.0007, device='cuda:1')
tensor(0.0052, device='cuda:1')
tensor(0.0018, device='cuda:1')
train Loss: 78.0855
 10%|████████▏                                                                         | 5/50 [00:13<02:02,  2.72s/it]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/tmp/ipykernel_1966757/4048788386.py in <module>
     40             # track history if only in train
     41             with torch.set_grad_enabled(phase == 'train'):
---> 42                 X,Xc,Xrppg, proba_fake, proba_real = model(inputs,targets)
     43
     44                 loss = criterion_G(Xc, Xrppg, proba_fake) + criterion_D(proba_fake, proba_real)

~/miniconda3/envs/baseline/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_1966757/3788774083.py in forward(self, X, Xc)
    141
    142     def forward(self, X, Xc):
--> 143         Xrppg = self.G(X)
    144
    145         # Input of original rppg and Ground Truth

~/miniconda3/envs/baseline/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_1966757/3788774083.py in forward(self, x)
     40         res4 = self.activ1(x)
     41
---> 42         x = self.conv5(res4)
     43         res5 = self.activ1(x)
     44

~/miniconda3/envs/baseline/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/envs/baseline/lib/python3.9/site-packages/torch/nn/modules/conv.py in forward(self, input)
    296
    297     def forward(self, input: Tensor) -> Tensor:
--> 298         return self._conv_forward(input, self.weight, self.bias)
    299
    300

~/miniconda3/envs/baseline/lib/python3.9/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    292                             weight, bias, self.stride,
    293                             _single(0), self.dilation, self.groups)
--> 294         return F.conv1d(input, weight, bias, self.stride,
    295                         self.padding, self.dilation, self.groups)
    296

KeyboardInterrupt:
[70]:
plt.plot(train_loss_history[10:])
plt.plot(val_loss_history[10:])
[70]:
[<matplotlib.lines.Line2D at 0x7fb5d5439820>]
../_images/notebooks_train_11_1.png

Examples of usage and results

[71]:
X,Xc = next(iter(testloader))
[72]:
inputs = X.to(device)
targets = Xc.to(device)
[73]:
X,Xc,Xrppg, _, _ = model(inputs,targets)
[74]:
original = X.cpu().numpy()
ground_truth = Xc.cpu().numpy()
corrected = Xrppg.cpu().detach().numpy()
[75]:
fig,ax = plt.subplots(1,3,figsize=(15,5))
ax[0].plot(original[0,0],label='main model rppg')
ax[0].plot(ground_truth[0,0],label='ground truth')
ax[0].plot(corrected[0,0],label='corrected rppg')

ax[1].plot(original[1,0],label='main model rppg')
ax[1].plot(ground_truth[1,0],label='ground truth')
ax[1].plot(corrected[1,0],label='corrected rppg')

ax[2].plot(original[2,0],label='main model rppg')
ax[2].plot(ground_truth[2,0],label='ground truth')
ax[2].plot(corrected[2,0],label='corrected rppg')


ax[0].legend()
fig.suptitle('Illustration of PULSE GAN correction')
[75]:
Text(0.5, 0.98, 'Illustration of PULSE GAN correction')
../_images/notebooks_train_17_1.png
[76]:
eg = pd.read_csv('exp46.csv')
eg['pred'] = normalize(eg['pred'])
eg['gt'] = normalize(eg['gt'])
[77]:
eg.iloc[1500:].plot(figsize=(15,5))
[77]:
<AxesSubplot:>
../_images/notebooks_train_19_1.png
[91]:
model = PulseGan()
model = model.to(device)

model.load_state_dict(torch.load('./checkpoints/PulseGan3_49.pt'))
[91]:
<All keys matched successfully>
[83]:
INPUTS = []
TARGETS = []
for i in range(14):
    preds = np.array(eg['pred'].iloc[i*128:(i+1)*128])
    gts = np.array(eg['gt'].iloc[i*128:(i+1)*128])

    preds = torch.from_numpy(preds).type(torch.FloatTensor)
    gts = torch.from_numpy(gts).type(torch.FloatTensor)
    INPUTS.append(preds)
    TARGETS.append(gts)
[84]:
# Iterate over data.

originals = []
ground_truths = []
correcteds = []

for inputs, targets in zip(INPUTS, TARGETS):

    if inputs.shape[-1] < 128: break

    inputs = inputs.reshape(1,1,-1)
    targets = targets.reshape(1,1,-1)
    inputs = inputs.to(device)
    targets = targets.to(device)

    # forward
    # track history if only in train
    with torch.no_grad():
        X,Xc,Xrppg, proba_fake, proba_real = model(inputs,targets)


    original = inputs.cpu().numpy()
    ground_truth = targets.cpu().numpy()
    corrected = Xrppg.cpu().detach().numpy()

    originals.append(original)
    ground_truths.append(ground_truth)
    correcteds.append(corrected)

originals = np.concatenate(originals).flatten()
correcteds = np.concatenate(correcteds).flatten()
ground_truths = np.concatenate(ground_truths).flatten()
[85]:
fig,ax = plt.subplots(1,figsize=(15,5))

N = len(originals)


ax.plot(np.array(range(N))/29,originals,label='main model rppg',color='grey')
ax.plot(np.array(range(N))/29,ground_truths,label='ground truth',color='red')
ax.plot(np.array(range(N))/29,correcteds,label='corrected rppg',color='black')


ax.legend()
ax.set_xlim(55,62)

fig.suptitle('PULSE GAN correction on ROI motion with L1-match')
[85]:
Text(0.5, 0.98, 'PULSE GAN correction on ROI motion with L1-match')
../_images/notebooks_train_23_1.png

Compute correlation again for Physnet with correcting

[86]:
eg = pd.read_csv('../experiments/csvs/phys_pred2.csv')
[92]:
model.G()
[92]:
Generator(
  (conv1): Conv1d(1, 16, kernel_size=(3,), stride=(1,))
  (conv2): Conv1d(16, 32, kernel_size=(3,), stride=(1,))
  (conv3): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
  (conv4): Conv1d(64, 96, kernel_size=(3,), stride=(1,))
  (conv5): Conv1d(96, 128, kernel_size=(3,), stride=(1,))
  (conv6): Conv1d(128, 156, kernel_size=(3,), stride=(1,))
  (deconv6): ConvTranspose1d(156, 128, kernel_size=(3,), stride=(1,))
  (deconv5): ConvTranspose1d(128, 96, kernel_size=(3,), stride=(1,))
  (deconv4): ConvTranspose1d(96, 64, kernel_size=(3,), stride=(1,))
  (deconv3): ConvTranspose1d(64, 32, kernel_size=(3,), stride=(1,))
  (deconv2): ConvTranspose1d(32, 16, kernel_size=(3,), stride=(1,))
  (deconv1): ConvTranspose1d(16, 1, kernel_size=(3,), stride=(1,))
  (activ1): PReLU(num_parameters=1)
  (activ2): Tanh()
)
[100]:
Xrppg
[100]:
tensor([[[-0.8561, -0.9232, -0.9610, -0.9456, -0.8834, -0.1844,  0.7081,
           0.9840,  0.9986,  0.9845,  0.8769,  0.6075,  0.5246,  0.3487,
           0.4243,  0.3177,  0.1402,  0.0182, -0.2538, -0.5724, -0.7897,
          -0.8779, -0.7846, -0.5494, -0.2307,  0.2559,  0.6254,  0.9330,
           0.9190,  0.8416,  0.8377,  0.8337,  0.7903,  0.6682,  0.4204,
           0.0630,  0.0126, -0.0523, -0.3816, -0.7622, -0.9074, -0.9277,
          -0.8822, -0.7490, -0.3100,  0.5105,  0.9774,  0.9990,  0.9971,
           0.9776,  0.9094,  0.7729,  0.6486,  0.3702,  0.1803,  0.0061,
          -0.1639, -0.3513, -0.7080, -0.8981, -0.9694, -0.9678, -0.9877,
          -0.9966, -0.9951, -0.7126,  0.0204,  0.7281,  0.9817,  0.9862,
           0.9822,  0.9475,  0.8685,  0.6284,  0.4094,  0.0170, -0.0519,
          -0.2035, -0.6461, -0.9499, -0.9724, -0.9755, -0.9884, -0.9782,
          -0.8902, -0.7263, -0.6072,  0.5137,  0.9441,  0.9746,  0.9406,
           0.7834,  0.3989, -0.1910, -0.5312, -0.8172, -0.9524, -0.9942,
          -0.9996, -0.9999, -0.9995, -0.9968, -0.9276,  0.3549,  0.9384,
           0.9945,  0.9960,  0.9944,  0.9764,  0.5528,  0.3854,  0.2980,
          -0.1204, -0.6148, -0.8860, -0.9796, -0.9828, -0.9690, -0.9766,
          -0.9858, -0.8523, -0.5034,  0.2470,  0.5135,  0.4833,  0.6982,
           0.9892,  0.9945]]], device='cuda:1')
[101]:
ret = {
    'id': [],
    'pred': [],
    'gt': [],
    'corrected': []
}
for exp in eg.id.unique():
    interest = eg[eg.id==exp]
    for i in range(len(interest)//128):

        X = np.array(interest['pred'].iloc[i*128:(i+1)*128])
        X = torch.from_numpy(X).type(torch.FloatTensor)
        X = X.reshape(1,1,-1)
        X = X.to(device)

        Xrppg = model.G(X).cpu().detach().numpy().flatten()

        ret['id'] += [exp] * 128
        ret['pred'] += list(X.flatten().cpu().detach().numpy())
        ret['gt'] += list(interest['gt'].iloc[i*128:(i+1)*128])
        ret['corrected'] += list(Xrrpg)
[103]:
corrected_df = pd.DataFrame.from_dict(ret)
[108]:
ret = {
    'id': [],
    'corr': []
}
for exp in corrected_df.id.unique():
    interest = corrected_df[corrected_df.id==exp]
    Xc = interest['gt']
    Xrppg = interest['corrected']


    corr, _ = stats.pearsonr(Xc, Xrppg)
    ret['id'].append(exp)
    ret['corr'].append(corr)
[110]:
pd.DataFrame.from_dict(ret).mean()
/tmp/ipykernel_1948383/2229763696.py:1: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.
  pd.DataFrame.from_dict(ret).mean()
[110]:
corr    0.627004
dtype: float64
[ ]: