Segmentation Training
The team implemented the DeepLab v3 model from Google as our semantic segmentation model. This model has achieved great success across domains in numerous publications through its use of atrous convolution and spatial pyramid pooling. DeepLab v3 is available in the TorchVision library, and the team made two significant design choices to optimize performance. First, we selected the ResNet-50 backbone for the model, which strikes a good balance between model size and model accuracy. Next, we added an 8-node classification head to the model to align with our 8 segmentation classes: background 1, background 2, stage, Sonic, robots, items, hazards, and mechanicals.
The team elected to initialize the weights from a model pretrained on natural images (the code below loads TorchVision's COCO_WITH_VOC_LABELS_V1 DeepLab v3 weights), hypothesizing that there may be some features and artifacts that carry over from natural images. This model had a mean intersection over union (mIoU) of 0.4385 at the first epoch, validating our hypothesis of starting with a pretrained model. The team then generated a dataset of 40,000 synthetic images containing facets of Sonic environments from all Acts. These images were split into a dataset of 80% training images and 20% validation images. The model trained on this dataset for 45 epochs to reach a final performance of 0.7214 mIoU. A simple test function was created to run “eye tests” of the model's performance on unseen images, as seen below.

Our Code:
source/vision/deeplab.py
1##---------------Source-------------------------##
2# Montalvo, J., García-Martín, Á. & Bescós, J. Exploiting semantic segmentation to boost reinforcement learning in video game environments.
3# Multimed Tools Appl (2022). https://doi.org/10.1007/s11042-022-13695-1
4##---------------Source-------------------------##
5
6import os
7import sys
8os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
9
10script_dir = os.path.dirname(os.path.abspath(__file__))
11project_dir = os.path.abspath(script_dir + "/../..")
12
13sys.path.append(os.path.abspath(project_dir + '/source/datasets')) # add learning directory
14
15from PIL import Image
16import numpy as np
17import random
18import matplotlib.pyplot as plt
19import time
20import torch
21import torch.nn as nn
22from torch import optim
23from torchvision import models, transforms
24import torch.nn.functional as F
25from torchvision.models.segmentation.deeplabv3 import DeepLabHead
26from torchvision import models
27
28from image_tuple import *
29
class DeepLab:
    """DeepLabV3 (ResNet-50 backbone) wrapper for semantic segmentation.

    Bundles model construction, a one-epoch training loop, validation
    (cross-entropy loss and mean IoU), colour decoding of label maps and
    single-image inference.
    """

    def __init__(self, weight_file=None):
        """Build the model, its optimizer and its learning-rate scheduler.

        Args:
            weight_file: Optional path to a ``state_dict`` written with
                ``torch.save``. When given, it is loaded on top of the
                TorchVision pretrained weights.
        """
        self.pre_load = "True"  # Load dataset in memory
        self.pre_trained = "True"
        # Must equal the width of the classification head built below: the
        # confusion matrix in ``testing`` is sized from this value, so a
        # smaller number silently drops classes (or breaks the histogram).
        self.num_classes = 8
        self.ignore_label = 255
        self.lr = 0.001  # 0.001 if pretrained weights from pytorch. 0.1 if scratch
        self.M = [37, 42]  # LR milestones (epochs) for the MultiStepLR scheduler

        self.seed = 42

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Set random seeds for reproducibility.
        torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
        torch.manual_seed(self.seed)  # CPU seed
        torch.cuda.manual_seed_all(self.seed)  # GPU seed
        random.seed(self.seed)  # python seed for image transformation
        np.random.seed(self.seed)

        self.workers = 0  # Anything over 0 will crash on windows. On linux it should work fine.

        # Number of ``train_epoch`` calls made so far; used only for logging.
        self._epoch = 0

        model = models.segmentation.deeplabv3_resnet50(
            weights='DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1',
            progress=True)
        # Replace the pretrained head with one sized for our label set.
        model.classifier = DeepLabHead(2048, self.num_classes)
        if weight_file is not None:
            model.load_state_dict(torch.load(weight_file, map_location=self.device))
        model = model.to(self.device)
        self.model = model
        # self.optimizer = optim.SGD(model.parameters(), lr=self.lr, momentum=0.9, weight_decay=1e-4)
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)

        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.M, gamma=0.1)

    def train_epoch(self, args, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Args:
            args: Parsed command-line namespace. Kept for interface
                compatibility; it is not read here any more (the log line
                used to print ``args.epochs`` as if it were the current
                epoch).
            train_loader: Iterable of ``(images, mask)`` batches that also
                exposes ``__len__`` and ``.dataset`` (used for logging).

        Returns:
            The mean of the per-batch losses for this epoch.

        Raises:
            ValueError: If ``train_loader`` yields no batches.
        """
        # switch to train mode
        self.model.train()

        # Count epochs locally so the progress line shows the epoch that is
        # actually running rather than the configured total.
        self._epoch = getattr(self, "_epoch", 0) + 1

        criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_label)
        train_loss = []

        for counter, (images, mask) in enumerate(train_loader, start=1):
            images, mask = images.to(self.device), mask.to(self.device)

            outputs = self.model(images)['out']

            # Aggregated per-pixel loss; the mask's channel axis is dropped.
            loss = criterion(outputs, mask.squeeze(1))
            train_loss.append(loss.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if counter % 15 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}, Learning rate: {:.6f}'.format(
                    self._epoch, int(counter * len(images)), len(train_loader.dataset),
                    100. * counter / len(train_loader), loss.item(),
                    self.optimizer.param_groups[0]['lr']))

        if not train_loss:
            raise ValueError("train_loader produced no batches")

        return sum(train_loss) / len(train_loss)  # per batch averaged loss for the current epoch.

    def _fast_hist(self, label_pred, label_true, num_classes):
        """Return the ``num_classes x num_classes`` confusion matrix.

        Rows index the true class and columns the predicted class. Pixels
        whose true label lies outside ``[0, num_classes)`` (for example the
        ignore label) are skipped.

        Args:
            label_pred: Flat array of predicted class ids.
            label_true: Flat array of ground-truth class ids.
            num_classes: Number of classes.
        """
        mask = (label_true >= 0) & (label_true < num_classes)
        hist = np.bincount(
            num_classes * label_true[mask].astype(int) + label_pred[mask].astype(int),
            minlength=num_classes ** 2).reshape(num_classes, num_classes)
        return hist

    def testing(self, test_loader):
        """Evaluate the model on ``test_loader``.

        Args:
            test_loader: Iterable of ``(images, mask)`` batches exposing
                ``.dataset`` (used for logging).

        Returns:
            ``(loss_per_epoch, mean_iou)`` where ``loss_per_epoch`` is a
            one-element list holding the mean batch loss and ``mean_iou``
            is the IoU averaged over the classes that occur (classes with
            an empty union are excluded via ``nanmean``).
        """
        self.model.eval()

        criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_label)
        loss_per_batch = []

        # Accumulate the confusion matrix batch by batch instead of keeping
        # every prediction and label map in memory until the end.
        hist = np.zeros((self.num_classes, self.num_classes))
        with torch.no_grad():
            for images, mask in test_loader:
                images, mask = images.to(self.device), mask.to(self.device)

                outputs = self.model(images)['out']

                loss = criterion(outputs, mask.squeeze(1))
                loss_per_batch.append(loss.item())

                preds = outputs.argmax(dim=1).cpu().numpy()
                truth = mask.cpu().numpy()
                hist += self._fast_hist(preds.flatten(), truth.flatten(), self.num_classes)

        loss_per_epoch = [np.average(loss_per_batch)]

        # Classes absent from both prediction and ground truth give 0/0;
        # they become NaN and are skipped by nanmean below.
        with np.errstate(divide='ignore', invalid='ignore'):
            iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))

        mean_iou = np.nanmean(iou)

        print('\nTest set ({:.0f}): Average loss: {:.4f}, mIoU: {:.4f}\n'.format(
            len(test_loader.dataset), loss_per_epoch[-1], mean_iou))

        return (loss_per_epoch, mean_iou)

    def decode_segmap(self, image, nc=8):
        """Convert a 2-D map of class ids into an RGB image.

        Args:
            image: 2-D array of class ids.
            nc: Number of classes to colour; ids ``>= nc`` stay black.
                The palette holds 21 entries.

        Returns:
            ``uint8`` array of shape ``image.shape + (3,)``.
        """
        # 21-entry palette; only the first ``nc`` rows are used.
        label_colors = np.array([
            (0, 0, 0),
            (0, 0, 255), (127, 127, 0), (0, 255, 0), (255, 0, 0), (255, 255, 0),
            (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
            (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
            (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])

        r = np.zeros_like(image).astype(np.uint8)
        g = np.zeros_like(image).astype(np.uint8)
        b = np.zeros_like(image).astype(np.uint8)
        for label in range(nc):
            idx = image == label
            r[idx] = label_colors[label, 0]
            g[idx] = label_colors[label, 1]
            b[idx] = label_colors[label, 2]
        return np.stack([r, g, b], axis=2)

    def seg_test(self, path, transform=None):
        """Segment the image at ``path`` and show it next to its colour mask.

        Args:
            path: Path of the image file to open.
            transform: Callable turning the PIL image into a ``(C, H, W)``
                tensor. Defaults to ``transforms.ToTensor()``.
        """
        img = Image.open(path).convert('RGB')

        # ``segment`` switches to eval mode, disables autograd and prints
        # the segmentation time.
        segm = self.segment(img, transform)
        segm_rgb = self.decode_segmap(segm)

        fig = plt.figure()
        fig.add_subplot(1, 2, 1)
        plt.imshow(img)
        plt.axis('off')
        fig.add_subplot(1, 2, 2)
        plt.imshow(segm_rgb)
        plt.axis('off')
        #plt.savefig('1_1.png', format='png',dpi=300,bbox_inches = "tight")
        plt.show()

    def segment(self, image, transform=None):
        """Return the per-pixel class-id map for a single image.

        Args:
            image: Input accepted by ``transform``.
            transform: Callable turning ``image`` into a ``(C, H, W)``
                tensor. Defaults to ``transforms.ToTensor()``.

        Returns:
            2-D NumPy array of predicted class ids.
        """
        if transform is None:
            transform = transforms.ToTensor()
        input_image = transform(image).unsqueeze(0).to(self.device)

        # Inference runs on a batch of one: BatchNorm must use its running
        # statistics (eval mode), and no autograd graph is needed.
        self.model.eval()
        timer = time.time()
        with torch.no_grad():
            out = self.model(input_image)['out'][0]
        print (f'Segmentation Time: {time.time()-timer}')

        return torch.argmax(out, dim=0).cpu().numpy()
source/drivers/deeplab_train.py
1##---------------Source-------------------------##
2# Montalvo, J., García-Martín, Á. & Bescós, J. Exploiting semantic segmentation to boost reinforcement learning in video game environments.
3# Multimed Tools Appl (2022). https://doi.org/10.1007/s11042-022-13695-1
4##---------------Source-------------------------##
5import os
6import sys
7
8os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
9script_dir = os.path.dirname(os.path.abspath(__file__))
10project_dir = os.path.abspath(script_dir + "/../..")
11
12
13sys.path.append(os.path.abspath(project_dir + '/source/datasets'))
14sys.path.append(os.path.abspath(project_dir + '/source/vision'))
15
16from deeplab import *
17from deeplab_dataset import *
18
19from PIL import Image
20import numpy as np
21import random
22import argparse
23import time
24from os.path import join
25from tqdm import tqdm
26
27import torch
28
29
parser = argparse.ArgumentParser()
parser.add_argument("-m","--model",default=None,type=str, help="Name of a partially trained model. Training will continue to optimize these set of weights.")
parser.add_argument("-o","--output_file",default="SegmentationModel",type=str, help="Name of the model. Will be saved on results/deeplab_ckpts")
parser.add_argument("-bs","--batch_size",default=4, choices=range(2,32),type=int, help="Keep it always 2 or more, otherwise it will crash.") ##
parser.add_argument("-d","--dataset", default='data/segmentation_dataset', help="Path to dataset",type=str)
parser.add_argument("-e","--epochs",default=45,type=int,help="Epochs")


def main():
    """Train DeepLab for ``--epochs`` epochs, validating after each one.

    A checkpoint ``<output_file>_<epoch>.pt`` is written after every epoch
    under ``results/deeplab_ckpts``. In addition, whenever the validation
    mIoU improves, the weights are also written to ``<output_file>.pt`` so
    the best model is available under a fixed name.
    """
    # Parsed here rather than at import time so importing this module has
    # no command-line side effects.
    args = parser.parse_args()

    deep_lab = DeepLab(args.model)

    trainset = SonicDataset(args, 'train')
    # drop_last: a trailing batch of a single sample would hit the same
    # crash the --batch_size help text warns about.
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True,
        num_workers=deep_lab.workers, pin_memory=True, drop_last=True)

    testset = SonicDataset(args, 'val')
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=args.batch_size, shuffle=False,
        num_workers=deep_lab.workers, pin_memory=True)

    path = os.path.join(project_dir, "results", "deeplab_ckpts")
    os.makedirs(path, exist_ok=True)

    best_acc_val = float("-inf")

    for epoch in tqdm(range(1, args.epochs + 1), desc="DeepLabV3_Resnet50 training"):
        deep_lab.train_epoch(args, train_loader)

        deep_lab.scheduler.step()

        _, acc_val_per_epoch_i = deep_lab.testing(test_loader)

        if acc_val_per_epoch_i > best_acc_val:
            best_acc_val = acc_val_per_epoch_i
            torch.save(deep_lab.model.state_dict(), os.path.join(path, f'{args.output_file}.pt'))

        torch.save(deep_lab.model.state_dict(), os.path.join(path, f'{args.output_file}_{epoch}.pt'))


if __name__ == "__main__":
    main()
source/drivers/deeplab_test.py
1##---------------Source-------------------------##
2# Montalvo, J., García-Martín, Á. & Bescós, J. Exploiting semantic segmentation to boost reinforcement learning in video game environments.
3# Multimed Tools Appl (2022). https://doi.org/10.1007/s11042-022-13695-1
4##---------------Source-------------------------##
5import os
6import sys
7
8script_dir = os.path.dirname(os.path.abspath(__file__))
9project_dir = os.path.abspath(script_dir + "/../..")
10
11
12sys.path.append(os.path.abspath(project_dir + '/source/datasets'))
13sys.path.append(os.path.abspath(project_dir + '/source/vision'))
14
15from deeplab import *
16from deeplab_dataset import *
17import argparse
18
19
20
21
parser = argparse.ArgumentParser()
parser.add_argument("-m","--model",default='results/deeplab_ckpts/SegmentationModel.pt',type=str, help="Path to the trained weights file to load.")
parser.add_argument("-i","--image",required=True, help="Path to image",type=str)


def main():
    """Load a trained DeepLab model and display its segmentation of one image."""
    # Parsed here rather than at import time so importing this module has
    # no command-line side effects.
    args = parser.parse_args()

    # Use the model path as given when it exists; otherwise fall back to
    # resolving it against the project root, which is how the image path
    # is resolved below.
    model_path = args.model
    if not os.path.exists(model_path):
        model_path = os.path.join(project_dir, args.model)

    seg = DeepLab(model_path)
    seg.seg_test(os.path.join(project_dir, args.image))


if __name__ == "__main__":
    main()