-
Notifications
You must be signed in to change notification settings - Fork 386
Expand file tree
/
Copy pathocr_utils.py
More file actions
158 lines (135 loc) · 6.02 KB
/
ocr_utils.py
File metadata and controls
158 lines (135 loc) · 6.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Ke Sun ([email protected]), Jingyi Xie ([email protected])
#
# This code is from: https://github.com/HRNet/HRNet-Semantic-Segmentation
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import cfg
from network.utils import BNReLU
class SpatialGather_Module(nn.Module):
    """
    Aggregate the context features according to the initial
    predicted probability distribution.
    Employ the soft-weighted method to aggregate the context.

    Every class map in ``probs`` is multiplied by ``scale``, normalised with
    a softmax taken over all spatial positions, and then used as the weights
    of a weighted sum over the pixel features in ``feats``.

    Parameters:
        cls_num : stored on the module; not read by ``forward``
        scale   : multiplier applied to ``probs`` before the spatial softmax
    Input:
        feats : N X C X H X W
        probs : N X K X H X W (same H and W as ``feats``)
    Output:
        The correlation of every class map with every feature map
        shape = [n, num_feats, num_classes, 1]
    """

    def __init__(self, cls_num=0, scale=1):
        super(SpatialGather_Module, self).__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        # flatten(2) merges the spatial dims like view(n, c, -1) did, but also
        # accepts non-contiguous inputs (e.g. cropped or transposed tensors),
        # for which view() raises a RuntimeError.
        probs = probs.flatten(2)                    # batch x k x hw
        feats = feats.flatten(2).permute(0, 2, 1)   # batch x hw x c

        # Softmax over dim 2: each class map becomes a distribution over pixels.
        probs = F.softmax(self.scale * probs, dim=2)

        ocr_context = torch.matmul(probs, feats)    # batch x k x c
        return ocr_context.permute(0, 2, 1).unsqueeze(3)  # batch x c x k x 1
class ObjectAttentionBlock(nn.Module):
    '''
    Attention from every pixel of a feature map onto a set of proxy
    (object region) features.
    Input:
        x     : N X C X H X W pixel features
        proxy : object features; projected by 1x1 convolutions and flattened
                over their trailing dimensions
    Parameters:
        in_channels   : the dimension of the input feature map
        key_channels  : the dimension after the key/query transform
        scale         : max-pool factor applied to x before attention when
                        larger than 1 (save memory cost); the result is
                        bilinearly resized back to H X W
    Return:
        N X C X H X W
    '''

    def __init__(self, in_channels, key_channels, scale=1):
        super(ObjectAttentionBlock, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))

        # Submodules are created in the order f_pixel, f_object, f_down, f_up,
        # each as a flat Sequential of (conv, BNReLU) pairs.
        self.f_pixel = nn.Sequential(
            *self._conv_bn_relu(in_channels, key_channels),
            *self._conv_bn_relu(key_channels, key_channels),
        )
        self.f_object = nn.Sequential(
            *self._conv_bn_relu(in_channels, key_channels),
            *self._conv_bn_relu(key_channels, key_channels),
        )
        self.f_down = nn.Sequential(
            *self._conv_bn_relu(in_channels, key_channels),
        )
        self.f_up = nn.Sequential(
            *self._conv_bn_relu(key_channels, in_channels),
        )

    @staticmethod
    def _conv_bn_relu(src_channels, dst_channels):
        # One bias-free 1x1 convolution followed by BNReLU on its output.
        return [
            nn.Conv2d(src_channels, dst_channels, kernel_size=1, stride=1,
                      padding=0, bias=False),
            BNReLU(dst_channels),
        ]

    def forward(self, x, proxy):
        n = x.size(0)
        full_res = (x.size(2), x.size(3))

        if self.scale > 1:
            x = self.pool(x)
        pooled_hw = x.size()[2:]

        # query: n x hw x key, key: n x key x k, value: n x k x key
        query = self.f_pixel(x).view(n, self.key_channels, -1).permute(0, 2, 1)
        key = self.f_object(proxy).view(n, self.key_channels, -1)
        value = self.f_down(proxy).view(n, self.key_channels, -1).permute(0, 2, 1)

        # Scaled dot-product similarity, normalised over the proxy axis.
        attn = (self.key_channels ** -.5) * torch.matmul(query, key)
        attn = F.softmax(attn, dim=-1)

        # add bg context ...
        gathered = torch.matmul(attn, value).permute(0, 2, 1).contiguous()
        context = self.f_up(gathered.view(n, self.key_channels, *pooled_hw))

        if self.scale > 1:
            # Undo the pooling so the output matches the input resolution.
            context = F.interpolate(input=context, size=full_res,
                                    mode='bilinear',
                                    align_corners=cfg.MODEL.ALIGN_CORNERS)
        return context
class SpatialOCR_Module(nn.Module):
    """
    Implementation of the OCR module:
    We aggregate the global object representation to update the representation
    for each pixel.

    The object context computed by ObjectAttentionBlock is concatenated with
    the input features (and, when cfg.MODEL.OCR_ASPP is set, with an ASPP
    branch computed from the input features) along the channel axis and fused
    by a 1x1 convolution, BNReLU and Dropout2d.

    Parameters:
        in_channels  : channels of the pixel feature map ``feats``
        key_channels : channels used for key/query inside the attention block
        out_channels : channels of the fused output
        scale        : forwarded to ObjectAttentionBlock
        dropout      : probability passed to nn.Dropout2d
    """

    def __init__(self, in_channels, key_channels, out_channels, scale=1,
                 dropout=0.1):
        super(SpatialOCR_Module, self).__init__()
        self.object_context_block = ObjectAttentionBlock(in_channels,
                                                         key_channels,
                                                         scale)
        if cfg.MODEL.OCR_ASPP:
            # get_aspp is not imported at module level in this file, so this
            # branch used to fail with NameError. Imported here, scoped to the
            # only branch that needs it.
            # NOTE(review): assumed to live in network.utils next to BNReLU -
            # confirm the location.
            from network.utils import get_aspp
            self.aspp, aspp_out_ch = get_aspp(
                in_channels, bottleneck_ch=cfg.MODEL.ASPP_BOT_CH,
                output_stride=8)
            # context (in_channels) + aspp (aspp_out_ch) + feats (in_channels)
            _in_channels = 2 * in_channels + aspp_out_ch
        else:
            # context (in_channels) + feats (in_channels)
            _in_channels = 2 * in_channels

        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0,
                      bias=False),
            BNReLU(out_channels),
            nn.Dropout2d(dropout)
        )

    def forward(self, feats, proxy_feats):
        context = self.object_context_block(feats, proxy_feats)

        # Concatenation order must match the channel count used in __init__.
        if cfg.MODEL.OCR_ASPP:
            aspp = self.aspp(feats)
            output = self.conv_bn_dropout(torch.cat([context, aspp, feats], 1))
        else:
            output = self.conv_bn_dropout(torch.cat([context, feats], 1))

        return output