I followed this uploader's video to swap the module into YOLOv8 and had it running pretty quickly:
Bilibili: 飞飞
The catch is that on my problem, Snake Conv made accuracy worse rather than better. I noticed it tends to confuse samples whose features look similar but that shouldn't count as the same class. My problem is tricky: the data is imbalanced and scarce (imbalanced and scarce by design), so I need a method that doesn't depend much on data volume. Emmm… never mind, maybe it will still be useful for your problem~
Beyond that, I also had to convert the weights to TensorRT. A few spots in the code from 飞飞's video need small edits before the TensorRT export goes through:
The zero, max_y, and max_x tensors below all have to live on the torch CUDA device, so add .to(device) where needed.
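To see why this matters, here is a minimal standalone sketch (not part of the DSConv code; the shapes and values are made up) of the failure mode: constant tensors created on the CPU and then mixed into clamp calls on CUDA tensors, which either raises a device-mismatch error or breaks the export trace.

import torch

# Hypothetical illustration only -- assumes a CUDA device is available.
device = torch.device('cuda')
y0 = torch.randint(0, 40, (10,), device=device).int()   # CUDA tensor, like the grid indices in DSConv

zero  = torch.zeros([]).int()                 # created on CPU
max_y = torch.tensor(31).int()                # created on CPU

# torch.clamp(y0, zero, max_y)                # CUDA input + CPU bounds -> device-mismatch problems

zero  = zero.to(device)                       # the fix: move every constant onto the input's device
max_y = max_y.to(device)
y0 = torch.clamp(y0, zero, max_y)             # all operands now on CUDA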
def _bilinear_interpolate_3D(self, input_feature, y, x):
    device = input_feature.device
    y = y.reshape([-1]).float()
    x = x.reshape([-1]).float()

    zero = torch.zeros([]).int().to(device)
    max_y = torch.tensor(self.width - 1)
    max_x = torch.tensor(self.height - 1)

    # find 8 grid locations
    y0 = torch.floor(y).int()
    y1 = y0 + 1
    x0 = torch.floor(x).int()
    x1 = x0 + 1

    # clip out coordinates exceeding feature map volume
    y0 = torch.clamp(y0, zero, max_y.to(device))
    y1 = torch.clamp(y1, zero, max_y.to(device))
    x0 = torch.clamp(x0, zero, max_x.to(device))
    x1 = torch.clamp(x1, zero, max_x.to(device))

    input_feature_flat = input_feature.flatten()
    input_feature_flat = input_feature_flat.reshape(
        self.num_batch, self.num_channels, self.width, self.height)
    input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
    input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)

    dimension = self.height * self.width
    base = torch.arange(self.num_batch) * dimension
    base = base.reshape([-1, 1]).float()

    repeat = torch.ones([self.num_points * self.width * self.height]).unsqueeze(0)
    repeat = repeat.float()

    base = torch.matmul(base, repeat)
    base = base.reshape([-1])
    base = base.to(device)

    base_y0 = base + y0 * self.height
    base_y1 = base + y1 * self.height

    # top rectangle of the neighbourhood volume
    index_a0 = base_y0 - base + x0
    index_c0 = base_y0 - base + x1

    # bottom rectangle of the neighbourhood volume
    index_a1 = base_y1 - base + x0
    index_c1 = base_y1 - base + x1

    # get 8 grid values
    value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device)
    value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device)
    value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device)
    value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device)

    # find 8 grid locations
    y0 = torch.floor(y).int()
    y1 = y0 + 1
    x0 = torch.floor(x).int()
    x1 = x0 + 1

    # clip out coordinates exceeding feature map volume
    y0 = torch.clamp(y0, zero, max_y.to(device) + 1)
    y1 = torch.clamp(y1, zero, max_y.to(device) + 1)
    x0 = torch.clamp(x0, zero, max_x.to(device) + 1)
    x1 = torch.clamp(x1, zero, max_x.to(device) + 1)

    x0_float = x0.float()
    x1_float = x1.float()
    y0_float = y0.float()
    y1_float = y1.float()

    vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device)
    vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device)
    vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device)
    vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device)

    outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 +
               value_a1 * vol_a1 + value_c1 * vol_c1)

    if self.morph == 0:
        outputs = outputs.reshape([
            self.num_batch,
            self.num_points * self.width,
            1 * self.height,
            self.num_channels,
        ])
        outputs = outputs.permute(0, 3, 1, 2)
    else:
        outputs = outputs.reshape([
            self.num_batch,
            1 * self.width,
            self.num_points * self.height,
            self.num_channels,
        ])
        outputs = outputs.permute(0, 3, 1, 2)
    return outputs

def deform_conv(self, input, offset, if_offset):
    y, x = self._coordinate_map_3D(offset, if_offset)
    deformed_feature = self._bilinear_interpolate_3D(input, y, x)
    return deformed_feature
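With those device fixes in place, the export itself can go through the standard Ultralytics route. A sketch, assuming best.pt under the default runs/ path is your trained checkpoint, TensorRT is installed locally, and the exact flags match your ultralytics version:

from ultralytics import YOLO

# Hypothetical checkpoint path from your own training run.
model = YOLO('runs/detect/train/weights/best.pt')

# Export straight to a TensorRT engine (requires a local TensorRT installation).
model.export(format='engine', device=0, half=True)

# Alternatively, export to ONNX first and build the engine with trtexec afterwards.
model.export(format='onnx', opset=12)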
Anyway, the resulting structure does look pretty wild.
After that I tried swapping whichever layers could be swapped, more or less at random, and in the end none of the variants beat the original C2f. The config for the full swap is below:
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f_DySnakeConv, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f_DySnakeConv, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f_DySnakeConv, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f_DySnakeConv, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f_DySnakeConv, [512]]  # 12
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f_DySnakeConv, [256]]  # 15 (P3/8-small)
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f_DySnakeConv, [512]]  # 18 (P4/16-medium)
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f_DySnakeConv, [1024]]  # 21 (P5/32-large)
  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)
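For reference, this is roughly how the modified yaml gets used. A sketch only: the filename is my own placeholder, the dataset yaml is a placeholder, and C2f_DySnakeConv has to be registered inside ultralytics (as shown in 飞飞's video), otherwise the yaml will not parse.

from ultralytics import YOLO

# Hypothetical filename for the yaml above.
model = YOLO('yolov8n-C2f-DySnakeConv.yaml')
model.info()  # sanity check that the C2f_DySnakeConv layers were actually built

# Train against your own dataset definition (placeholder path).
model.train(data='my_dataset.yaml', epochs=100, imgsz=640)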
Later I'll go back over the Snake Conv paper and try to guess, or actually understand, why it doesn't work well on my problem~ (beyond the data volume issue).