Segmentation Training

The team implemented the DeepLab v3 model from Google as our semantic segmentation model. This model has achieved great success across domains in numerous publications through its use of atrous convolution and spatial pyramid pooling. DeepLab v3 is available in the TorchVision library, and the team made two significant design choices to optimize performance. First, we used the ResNet50 backbone of the model, which strikes a good balance between model size and model accuracy. Next, we replaced the model's classification head with an 8-output head to align with our 8 segmented classes: background 1, background 2, stage, Sonic, robots, items, hazards, and mechanicals.

The team elected to initialize the weights from a model pretrained on natural images (the code loads TorchVision's COCO_WITH_VOC_LABELS_V1 weights), hypothesizing that some features and artifacts carry over from natural images. This model had a mean intersection over union (mIoU) of 0.4385 at the first epoch, validating our hypothesis of starting with a pretrained model. The team then generated a dataset of 40,000 synthetic images containing facets of Sonic environments from all Acts. These images were split into 80% training images and 20% validation images. The model trained on this dataset for 45 epochs to reach a final performance of 0.7214 mIoU. A simple test function was created to run “eye tests” of the model's performance on unseen images, as seen below.

Preprocessed sonic image

Our Code:

source/vision/deeplab.py

1##---------------Source-------------------------##
2# Montalvo, J., García-Martín, Á. & Bescós, J. Exploiting semantic segmentation to boost reinforcement learning in video game environments. 
3# Multimed Tools Appl (2022). https://doi.org/10.1007/s11042-022-13695-1
4##---------------Source-------------------------##
5
6import os
7import sys
8os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
9
10script_dir = os.path.dirname(os.path.abspath(__file__))
11project_dir = os.path.abspath(script_dir + "/../..")
12
13sys.path.append(os.path.abspath(project_dir + '/source/datasets'))	# add learning directory
14
15from PIL import Image
16import numpy as np
17import random
18import matplotlib.pyplot as plt
19import time
20import torch
21import torch.nn as nn
22from torch import optim
23from torchvision import  models, transforms
24import torch.nn.functional as F
25from torchvision.models.segmentation.deeplabv3 import DeepLabHead
26from torchvision import models
27
28from image_tuple import * 
29
class DeepLab:
    """DeepLabV3 (ResNet50 backbone) wrapper for training and inference.

    Holds the TorchVision model with its classifier replaced by a
    ``num_classes``-output head, an Adam optimizer and a MultiStepLR
    scheduler, plus helpers to train one epoch, evaluate mIoU, colorize a
    label map and segment single images.
    """

    def __init__(self, weight_file=None):
        """Build the model, optimizer and scheduler and seed all RNGs.

        Args:
            weight_file: Optional path to a ``state_dict`` checkpoint that is
                loaded into the model after the classifier head is replaced.
        """
        self.pre_load = "True"  ## Load dataset in memory
        self.pre_trained = "True"
        # Must agree with the classifier head below and with the default
        # ``nc`` of decode_segmap; a smaller value corrupts the confusion
        # matrix as soon as the model predicts one of the missing classes.
        self.num_classes = 8
        self.ignore_label = 255
        self.lr = 0.001  # 0.001 if pretrained weights from pytorch. 0.1 if scratch
        # Epochs at which MultiStepLR multiplies the learning rate by gamma.
        self.M = [37, 42]

        self.seed = 42

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Set random seeds for reproducibility.
        torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
        torch.manual_seed(self.seed)  # CPU seed
        torch.cuda.manual_seed_all(self.seed)  # GPU seed
        random.seed(self.seed)  # python seed for image transformation
        np.random.seed(self.seed)

        self.workers = 0  # Anything over 0 will crash on windows. On linux it should work fine.

        model = models.segmentation.deeplabv3_resnet50(
            weights='DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1',
            progress=True)
        # Swap the pretrained head for one sized to our label set.
        model.classifier = DeepLabHead(2048, self.num_classes)
        if weight_file is not None:
            model.load_state_dict(torch.load(weight_file, map_location=self.device))
        self.model = model.to(self.device)

        # self.optimizer = optim.SGD(model.parameters(), lr=self.lr, momentum=0.9, weight_decay=1e-4)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.M, gamma=0.1)

    def train_epoch(self, args, train_loader, epoch=None):
        """Run one optimization pass over ``train_loader``.

        Args:
            args: Namespace providing ``epochs``; only read for the progress
                message when ``epoch`` is not given.
            train_loader: Iterable of ``(images, mask)`` batches exposing
                ``dataset`` and ``__len__``.
                NOTE(review): ``mask`` is presumably ``(B, 1, H, W)`` class
                ids; ``squeeze(1)`` drops the singleton channel - confirm
                against the dataset.
            epoch: Optional current epoch number shown in the progress
                message. When omitted, ``args.epochs`` is printed, as before.

        Returns:
            The mean of the per-batch losses, or ``nan`` if the loader
            yielded no batches.
        """
        self.model.train()

        criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_label)
        batch_losses = []

        for step, (images, mask) in enumerate(train_loader, start=1):
            images, mask = images.to(self.device), mask.to(self.device)

            outputs = self.model(images)['out']

            # Aggregated per-pixel loss
            loss = criterion(outputs, mask.squeeze(1))
            batch_losses.append(loss.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if step % 15 == 0:
                shown_epoch = args.epochs if epoch is None else epoch
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Learning rate: {:.6f}'.format(
                    shown_epoch, int(step * len(images)), len(train_loader.dataset),
                    100. * step / len(train_loader), loss.item(),
                    self.optimizer.param_groups[0]['lr']))

        if not batch_losses:
            return float("nan")
        return sum(batch_losses) / len(batch_losses)  # per batch averaged loss for the current epoch.

    def _fast_hist(self, label_pred, label_true, num_classes):
        """Return the ``num_classes x num_classes`` confusion matrix.

        Rows index the ground-truth class and columns the predicted class.
        Pixels whose ground-truth label lies outside ``[0, num_classes)``
        (for example the ignore label) are skipped.

        Args:
            label_pred: 1-D integer array of predicted class ids.
            label_true: 1-D array of ground-truth class ids, same length.
            num_classes: Number of classes.
        """
        mask = (label_true >= 0) & (label_true < num_classes)
        hist = np.bincount(
            num_classes * label_true[mask].astype(int) + label_pred[mask],
            minlength=num_classes ** 2).reshape(num_classes, num_classes)
        return hist

    def testing(self, test_loader):
        """Evaluate the model on ``test_loader`` and print loss and mIoU.

        The confusion matrix is accumulated batch by batch so that the
        predictions of the whole set never have to be held in memory.

        Args:
            test_loader: Iterable of ``(images, mask)`` batches exposing
                ``dataset``.

        Returns:
            ``(loss_per_epoch, mean_iou)`` where ``loss_per_epoch`` is a
            one-element list holding the mean batch loss and ``mean_iou`` is
            the mean IoU over the classes that occur (NaN entries skipped).
        """
        self.model.eval()

        criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_label)
        loss_per_batch = []
        hist = np.zeros((self.num_classes, self.num_classes))

        with torch.no_grad():
            for images, mask in test_loader:
                images, mask = images.to(self.device), mask.to(self.device)

                outputs = self.model(images)['out']

                loss = criterion(outputs, mask.squeeze(1))
                loss_per_batch.append(loss.item())

                # Per-pixel class id with the highest score.
                preds = outputs.argmax(dim=1).cpu().numpy()
                truth = mask.cpu().numpy()
                hist += self._fast_hist(preds.flatten(), truth.flatten(), self.num_classes)

        loss_per_epoch = [np.average(loss_per_batch)]

        true_pos = np.diag(hist)
        union = hist.sum(axis=1) + hist.sum(axis=0) - true_pos
        # Classes absent from both prediction and ground truth give 0/0;
        # they become NaN and are excluded by nanmean.
        with np.errstate(divide='ignore', invalid='ignore'):
            iou = true_pos / union

        mean_iou = np.nanmean(iou)

        print('\nTest set ({:.0f}): Average loss: {:.4f}, mIoU: {:.4f}\n'.format(
            len(test_loader.dataset), loss_per_epoch[-1], mean_iou))

        return (loss_per_epoch, mean_iou)

    def decode_segmap(self, image, nc=8):
        """Convert a 2-D map of class ids into an ``(H, W, 3)`` uint8 RGB image.

        Args:
            image: 2-D array of class ids.
            nc: Number of classes to colorize. Ids outside ``[0, nc)`` are
                left black.
        """
        # Palette indexed by class id. It has 21 entries; only the first
        # ``nc`` are used.
        label_colors = np.array([
            (0, 0, 0),
            (0, 0, 255), (127, 127, 0), (0, 255, 0), (255, 0, 0), (255, 255, 0),
            (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
            (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
            (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
        ], dtype=np.uint8)

        image = np.asarray(image)
        rgb = np.zeros(image.shape + (3,), dtype=np.uint8)
        valid = (image >= 0) & (image < nc)
        rgb[valid] = label_colors[image[valid].astype(np.intp)]
        return rgb

    def seg_test(self, path, transform=None):
        """Segment the image at ``path`` and show it beside its label map.

        Args:
            path: Path of the image file to open.
            transform: Callable turning a PIL image into a tensor; defaults
                to ``transforms.ToTensor()``.
        """
        # Copy the pixels out so the file handle is closed right away.
        with Image.open(path) as handle:
            img = handle.convert('RGB')

        segm = self.segment(img, transform)
        segm_rgb = self.decode_segmap(segm)

        fig = plt.figure()
        fig.add_subplot(1, 2, 1)
        plt.imshow(img)
        plt.axis('off')
        fig.add_subplot(1, 2, 2)
        plt.imshow(segm_rgb)
        plt.axis('off')
        #plt.savefig('1_1.png', format='png',dpi=300,bbox_inches = "tight")
        plt.show()

    def segment(self, image, transform=None):
        """Return the per-pixel class-id map for a single image.

        The model is switched to eval mode first: a batch of one cannot go
        through the BatchNorm layers in training mode.

        Args:
            image: Input accepted by ``transform`` (a PIL image or array for
                the default transform).
            transform: Callable turning ``image`` into a ``(C, H, W)``
                tensor; defaults to ``transforms.ToTensor()``.

        Returns:
            ``(H, W)`` NumPy array of predicted class ids.
        """
        if transform is None:
            transform = transforms.ToTensor()

        input_image = transform(image).unsqueeze(0).to(self.device)
        self.model.eval()
        timer = time.time()
        with torch.no_grad():
            out = self.model(input_image)['out'][0]
        print (f'Segmentation Time: {time.time()-timer}')

        return torch.argmax(out, dim=0).cpu().numpy()

source/drivers/deeplab_train.py

1##---------------Source-------------------------##
2# Montalvo, J., García-Martín, Á. & Bescós, J. Exploiting semantic segmentation to boost reinforcement learning in video game environments. 
3# Multimed Tools Appl (2022). https://doi.org/10.1007/s11042-022-13695-1
4##---------------Source-------------------------##
5import os
6import sys
7
8os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
9script_dir = os.path.dirname(os.path.abspath(__file__))
10project_dir = os.path.abspath(script_dir + "/../..")
11
12
13sys.path.append(os.path.abspath(project_dir + '/source/datasets'))
14sys.path.append(os.path.abspath(project_dir + '/source/vision'))
15
16from deeplab import *
17from deeplab_dataset import *
18
19from PIL import Image
20import numpy as np
21import random
22import argparse
23import time
24from os.path import join
25from tqdm import tqdm
26
27import torch
28
29
parser = argparse.ArgumentParser()
parser.add_argument("-m","--model",default=None,type=str, help="Name of a partially trained model. Training will continue to optimize these set of weights.")
parser.add_argument("-o","--output_file",default="SegmentationModel",type=str, help="Name of the model. Will be saved on results/deeplab_ckpts")
parser.add_argument("-bs","--batch_size",default=4, choices=range(2,32),type=int, help="Keep it always 2 or more, otherwise it will crash.") ## 
parser.add_argument("-d","--dataset", default='data/segmentation_dataset', help="Path to dataset",type=str)
parser.add_argument("-e","--epochs",default=45,type=int,help="Epochs")


def main(argv=None):
    """Train DeepLab and write checkpoints to ``results/deeplab_ckpts``.

    After every epoch the weights are saved as ``<output_file>_<epoch>.pt``.
    Whenever the validation mIoU improves they are also saved as
    ``<output_file>.pt``, so the best model is always available under a
    fixed name.

    Args:
        argv: Optional list of command-line arguments; ``None`` makes
            argparse read ``sys.argv``.
    """
    # Parse here rather than at import time so importing this module has no
    # side effects.
    args = parser.parse_args(argv)

    deep_lab = DeepLab(args.model)
    trainset = SonicDataset(args, 'train')
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True,
        num_workers=deep_lab.workers, pin_memory=True)

    testset = SonicDataset(args, 'val')
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=args.batch_size, shuffle=False,
        num_workers=deep_lab.workers, pin_memory=True)

    ckpt_dir = os.path.join(project_dir, "results", "deeplab_ckpts")
    os.makedirs(ckpt_dir, exist_ok=True)

    best_miou = None
    progress = tqdm(range(1, args.epochs + 1), desc="DeepLabV3_Resnet50 training")
    for epoch in progress:
        train_loss = deep_lab.train_epoch(args, train_loader)
        deep_lab.scheduler.step()
        _, val_miou = deep_lab.testing(test_loader)

        weights = deep_lab.model.state_dict()
        torch.save(weights, os.path.join(ckpt_dir, f'{args.output_file}_{epoch}.pt'))

        # Keep a copy of the best-scoring weights under a stable name.
        if best_miou is None or val_miou > best_miou:
            best_miou = val_miou
            torch.save(weights, os.path.join(ckpt_dir, f'{args.output_file}.pt'))

        progress.set_postfix(train_loss=float(train_loss), best_mIoU=float(best_miou))


if __name__ == "__main__":
    main()

source/drivers/deeplab_test.py

1##---------------Source-------------------------##
2# Montalvo, J., García-Martín, Á. & Bescós, J. Exploiting semantic segmentation to boost reinforcement learning in video game environments. 
3# Multimed Tools Appl (2022). https://doi.org/10.1007/s11042-022-13695-1
4##---------------Source-------------------------##
5import os
6import sys
7
8script_dir = os.path.dirname(os.path.abspath(__file__))
9project_dir = os.path.abspath(script_dir + "/../..")
10
11
12sys.path.append(os.path.abspath(project_dir + '/source/datasets'))
13sys.path.append(os.path.abspath(project_dir + '/source/vision'))
14
15from deeplab import *
16from deeplab_dataset import *
17import argparse
18
19
20
21
22parser = argparse.ArgumentParser()
23parser.add_argument("-m","--model",default='results/deeplab_ckpts/SegmentationModel.pt',type=str, help="Name of the model. Will be saved on results/deeplab_ckpts")
24parser.add_argument("-i","--image",required=True, help="Path to image",type=str)
25args = parser.parse_args()
26
27def main():
28    seg = DeepLab(args.model)
29    seg.seg_test(os.path.join(project_dir, args.image))
30# Call main
31main()