99
1010Channels last memory format is an alternative way of ordering NCHW tensors in memory while preserving the dimension ordering. Channels last tensors are ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel).
1111
12- For example, classic (contiguous) storage of NCHW tensor (in our case it is two 2x2 images with 3 color channels) look like this:
12+ For example, classic (contiguous) storage of NCHW tensor (in our case it is two 4x4 images with 3 color channels) look like this:
1313
1414.. figure:: /_static/img/classic_memory_format.png
1515 :alt: classic_memory_format
3737######################################################################
3838# Classic PyTorch contiguous tensor
3939import torch
40+
4041N , C , H , W = 10 , 3 , 32 , 32
4142x = torch .empty (N , C , H , W )
42- print (x .stride ()) # Ouputs: (3072, 1024, 32, 1)
43+ print (x .stride ()) # Ouputs: (3072, 1024, 32, 1)
4344
4445######################################################################
4546# Conversion operator
4647x = x .to (memory_format = torch .channels_last )
47- print (x .shape ) # Outputs: (10, 3, 32, 32) as dimensions order preserved
48- print (x .stride ()) # Outputs: (3072, 1, 96, 3)
48+ print (x .shape ) # Outputs: (10, 3, 32, 32) as dimensions order preserved
49+ print (x .stride ()) # Outputs: (3072, 1, 96, 3)
4950
5051######################################################################
5152# Back to contiguous
5253x = x .to (memory_format = torch .contiguous_format )
53- print (x .stride ()) # Outputs: (3072, 1024, 32, 1)
54+ print (x .stride ()) # Outputs: (3072, 1024, 32, 1)
5455
5556######################################################################
5657# Alternative option
5758x = x .contiguous (memory_format = torch .channels_last )
58- print (x .stride ()) # Ouputs: (3072, 1, 96, 3)
59+ print (x .stride ()) # Ouputs: (3072, 1, 96, 3)
5960
6061######################################################################
6162# Format checks
62- print (x .is_contiguous (memory_format = torch .channels_last )) # Ouputs: True
63+ print (x .is_contiguous (memory_format = torch .channels_last )) # Ouputs: True
6364
6465######################################################################
6566# There are minor differences between the two APIs ``to`` and
8182# sizes are 1 in order to properly represent the intended memory
8283# format
8384special_x = torch .empty (4 , 1 , 4 , 4 )
84- print (special_x .is_contiguous (memory_format = torch .channels_last )) # Ouputs: True
85- print (special_x .is_contiguous (memory_format = torch .contiguous_format )) # Ouputs: True
85+ print (special_x .is_contiguous (memory_format = torch .channels_last )) # Ouputs: True
86+ print (special_x .is_contiguous (memory_format = torch .contiguous_format )) # Ouputs: True
8687
8788######################################################################
8889# Same thing applies to explicit permutation API ``permute``. In
99100######################################################################
100101# Create as channels last
101102x = torch .empty (N , C , H , W , memory_format = torch .channels_last )
102- print (x .stride ()) # Ouputs: (3072, 1, 96, 3)
103+ print (x .stride ()) # Ouputs: (3072, 1, 96, 3)
103104
104105######################################################################
105106# ``clone`` preserves memory format
106107y = x .clone ()
107- print (y .stride ()) # Ouputs: (3072, 1, 96, 3)
108+ print (y .stride ()) # Ouputs: (3072, 1, 96, 3)
108109
109110######################################################################
110111# ``to``, ``cuda``, ``float`` ... preserve memory format
111112if torch .cuda .is_available ():
112113 y = x .cuda ()
113- print (y .stride ()) # Ouputs: (3072, 1, 96, 3)
114+ print (y .stride ()) # Ouputs: (3072, 1, 96, 3)
114115
115116######################################################################
116117# ``empty_like``, ``*_like`` operators preserve memory format
117118y = torch .empty_like (x )
118- print (y .stride ()) # Ouputs: (3072, 1, 96, 3)
119+ print (y .stride ()) # Ouputs: (3072, 1, 96, 3)
119120
120121######################################################################
121122# Pointwise operators preserve memory format
122123z = x + y
123- print (z .stride ()) # Ouputs: (3072, 1, 96, 3)
124+ print (z .stride ()) # Ouputs: (3072, 1, 96, 3)
124125
125126######################################################################
126127# Conv, Batchnorm modules using cudnn backends support channels last
132133
133134if torch .backends .cudnn .version () >= 7603 :
134135 model = torch .nn .Conv2d (8 , 4 , 3 ).cuda ().half ()
135- model = model .to (memory_format = torch .channels_last ) # Module parameters need to be channels last
136+ model = model .to (memory_format = torch .channels_last ) # Module parameters need to be channels last
136137
137138 input = torch .randint (1 , 10 , (2 , 8 , 4 , 4 ), dtype = torch .float32 , requires_grad = True )
138139 input = input .to (device = "cuda" , memory_format = torch .channels_last , dtype = torch .float16 )
139140
140141 out = model (input )
141- print (out .is_contiguous (memory_format = torch .channels_last )) # Ouputs: True
142+ print (out .is_contiguous (memory_format = torch .channels_last )) # Ouputs: True
142143
143144######################################################################
144145# When the input tensor reaches an operator without channels last support,
195196
196197######################################################################
197198# Passing ``--channels-last true`` allows running a model in Channels last format with an observed 22% perf gain.
198- #
199+ #
199200# ``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data``
200201
201202# opt_level = O2
250251#
251252
252253# Need to be done once, after model initialization (or load)
253- model = model .to (memory_format = torch .channels_last ) # Replace with your model
254+ model = model .to (memory_format = torch .channels_last ) # Replace with your model
254255
255256# Need to be done for every input
256- input = input .to (memory_format = torch .channels_last ) # Replace with your input
257+ input = input .to (memory_format = torch .channels_last ) # Replace with your input
257258output = model (input )
258259
259260#######################################################################
271272# operators in your model that do not support channels last, if you
272273# want to improve the performance of converted model.
273274#
274- # That means you need to verify the list of used operators
275- # against supported operators list https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support,
275+ # That means you need to verify the list of used operators
276+ # against supported operators list https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support,
276277# or introduce memory format checks into eager execution mode and run your model.
277278#
278279# After running the code below, operators will raise an exception if the output of the
@@ -290,13 +291,13 @@ def contains_cl(args):
290291 return False
291292
292293
293- def print_inputs (args , indent = '' ):
294+ def print_inputs (args , indent = "" ):
294295 for t in args :
295296 if isinstance (t , torch .Tensor ):
296297 print (indent , t .stride (), t .shape , t .device , t .dtype )
297298 elif isinstance (t , list ) or isinstance (t , tuple ):
298299 print (indent , type (t ))
299- print_inputs (list (t ), indent = indent + ' ' )
300+ print_inputs (list (t ), indent = indent + " " )
300301 else :
301302 print (indent , t )
302303
@@ -311,32 +312,38 @@ def check_cl(*args, **kwargs):
311312 except Exception as e :
312313 print ("`{}` inputs are:" .format (name ))
313314 print_inputs (args )
314- print (' -------------------' )
315+ print (" -------------------" )
315316 raise e
316317 failed = False
317318 if was_cl :
318319 if isinstance (result , torch .Tensor ):
319320 if result .dim () == 4 and not result .is_contiguous (memory_format = torch .channels_last ):
320- print ("`{}` got channels_last input, but output is not channels_last:" .format (name ),
321- result .shape , result .stride (), result .device , result .dtype )
321+ print (
322+ "`{}` got channels_last input, but output is not channels_last:" .format (name ),
323+ result .shape ,
324+ result .stride (),
325+ result .device ,
326+ result .dtype ,
327+ )
322328 failed = True
323329 if failed and True :
324330 print ("`{}` inputs are:" .format (name ))
325331 print_inputs (args )
326- raise Exception (
327- 'Operator `{}` lost channels_last property' .format (name ))
332+ raise Exception ("Operator `{}` lost channels_last property" .format (name ))
328333 return result
334+
329335 return check_cl
330336
337+
331338old_attrs = dict ()
332339
340+
333341def attribute (m ):
334342 old_attrs [m ] = dict ()
335343 for i in dir (m ):
336344 e = getattr (m , i )
337- exclude_functions = ['is_cuda' , 'has_names' , 'numel' ,
338- 'stride' , 'Tensor' , 'is_contiguous' , '__class__' ]
339- if i not in exclude_functions and not i .startswith ('_' ) and '__call__' in dir (e ):
345+ exclude_functions = ["is_cuda" , "has_names" , "numel" , "stride" , "Tensor" , "is_contiguous" , "__class__" ]
346+ if i not in exclude_functions and not i .startswith ("_" ) and "__call__" in dir (e ):
340347 try :
341348 old_attrs [m ][i ] = e
342349 setattr (m , i , check_wrapper (e ))
@@ -352,16 +359,16 @@ def attribute(m):
352359
353360######################################################################
354361# If you found an operator that doesn't support channels last tensors
355- # and you want to contribute, feel free to use following developers
362+ # and you want to contribute, feel free to use following developers
356363# guide https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators.
357364#
358365
359366######################################################################
360367# Code below is to recover the attributes of torch.
361368
362369for (m , attrs ) in old_attrs .items ():
363- for (k ,v ) in attrs .items ():
364- setattr (m , k , v )
370+ for (k , v ) in attrs .items ():
371+ setattr (m , k , v )
365372
366373######################################################################
367374# Work to do
0 commit comments