Update scaled anchore dtype and fix cuda as device

This commit is contained in:
sosokker 2024-05-13 14:42:47 +07:00
parent 1894ff103b
commit 7f0c0baf60

View File

@ -118,7 +118,7 @@ class YOLOLayer(nn.Module):
# Calculate offsets for each grid
self.grid_x = torch.arange(g).repeat(g, 1).view([1, 1, g, g]).type(FloatTensor)
self.grid_y = torch.arange(g).repeat(g, 1).t().view([1, 1, g, g]).type(FloatTensor)
self.scaled_anchors = FloatTensor([(a_w / self.stride, a_h / self.stride) for a_w, a_h in self.anchors])
self.scaled_anchors = torch.tensor([(a_w / self.stride, a_h / self.stride) for a_w, a_h in self.anchors], dtype=torch.float32, device='cuda')
self.anchor_w = self.scaled_anchors[:, 0:1].view((1, self.num_anchors, 1, 1))
self.anchor_h = self.scaled_anchors[:, 1:2].view((1, self.num_anchors, 1, 1))