1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
|
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://www.nvidia.com/dli\"> <img src=\"images/DLI Header.png\" alt=\"Header\" style=\"width: 400px;\"/> </a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Effective Use of the Memory Subsystem"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that you can write correct CUDA kernels, and understand the importance of launching grids that give the GPU sufficient opportunity to hide latency, you are going to learn techniques to effectively utilize GPU memory subsystems. These techniques are widely applicable to a variety of CUDA applications, and some of the most important when it comes time to make your CUDA code go fast.\n",
"\n",
"You are going to begin by learning about memory coalescing. To challenge your ability to reason about memory coalescing, and to expose important details relevent to many CUDA applications, you will then learn about 2-dimensional grids and thread blocks. Next you will learn about a very fast, user-controlled, on-demand memory space called shared memory, and will use shared memory to facilitate memory coalescing where it would not have otherwise been possible. Finally, you will learn about shared memory bank conflicts, which can spoil the performance possibilities of using shared memory, and a technique to address them."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Objectives"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By the time you complete this section, you will be able to:\n",
"* Write CUDA kernels that benefit from coalesced memory access patterns.\n",
"* Work with multi-dimensional grids and thread blocks.\n",
"* Use shared memory to coordinate threads within a block.\n",
"* Use shared memory to facilitate coalesced memory access patterns.\n",
"* Resolve shared memory bank conflicts."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The Problem: Uncoalesced Memory Access Hurts Performance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before you learn the details about what **coalesced memory access** is, run the following cells to observe the performance implications for a seemingly trivial change to the data access pattern within a kernel."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import cuda"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data Creation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this cell we define `n` and create a grid with threads equal to `n`. We also create an output vector with length `n`. For the inputs we create vectors of size `stride * n` for reasons that will be made clear below:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"n = 1024*1024 # 1M\n",
"\n",
"threads_per_block = 1024\n",
"blocks = int(n / threads_per_block)\n",
"\n",
"stride = 16\n",
"\n",
"# Input Vectors of length stride * n\n",
"a = np.ones(stride * n).astype(np.float32)\n",
"b = a.copy().astype(np.float32)\n",
"\n",
"# Output Vector\n",
"out = np.zeros(n).astype(np.float32)\n",
"\n",
"d_a = cuda.to_device(a)\n",
"d_b = cuda.to_device(b)\n",
"d_out = cuda.to_device(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Kernel Definition"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `add_experiment`, every thread in the grid will add an item in `a`, and an item in `b` and write the result to `out`. The kernel has been written such that we can pass a `coalesced` value of either `True` or `False` to affect how it indexes into the `a` and `b` vectors. You will see the performance comparison of the two modes below."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"@cuda.jit\n",
"def add_experiment(a, b, out, stride, coalesced):\n",
" i = cuda.grid(1)\n",
" # The above line is equivalent to\n",
" # i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x\n",
" if coalesced == True:\n",
" out[i] = a[i] + b[i]\n",
" else:\n",
" out[i] = a[stride*i] + b[stride*i]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Launch Kernel Using Coalesced Access"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we pass `True` as the `coalesced` value, and observe the performance of the kernel over several runs:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"227 µs ± 73.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%timeit add_experiment[blocks, threads_per_block](d_a, d_b, d_out, stride, True); cuda.synchronize"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we make sure the kernel ran as expected:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"result = d_out.copy_to_host()\n",
"truth = a[:n] + b[:n]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.array_equal(result, truth)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Launch Kernel Using Uncoalesced Access"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this cell we pass `False`, to observe the perfomance of the uncoalesced data access pattern for `add_experiment`:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"540 µs ± 10.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"%timeit add_experiment[blocks, threads_per_block](d_a, d_b, d_out, stride, False); cuda.synchronize"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we make sure the kernel ran as expected:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"result = d_out.copy_to_host()\n",
"truth = a[::stride] + b[::stride]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.array_equal(result, truth)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The performance of the uncoalesced data access pattern was far worse. Now you will learn why, and how to think about data access patterns in your kernels to obtain high performing kernels."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Presentation: Global Memory Coalescing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Execute the following cell to load the slides, then click on \"Start Slide Show\" to make them full screen."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"800\"\n",
" height=\"450\"\n",
" src=\"https://view.officeapps.live.com/op/view.aspx?src=https://developer.download.nvidia.com/training/courses/C-AC-02-V1/coalescing-v3.pptx\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7f9086df57b8>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import IFrame\n",
"IFrame('https://view.officeapps.live.com/op/view.aspx?src=https://developer.download.nvidia.com/training/courses/C-AC-02-V1/coalescing-v3.pptx', 800, 450)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> _**Footnote**: for additional details about global memory segment size across a variety of devices, and with regards to caching, see [The CUDA Best Practices Guide](https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#coalesced-access-to-global-memory)._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise: Column and Row Sums"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For this exercise you will be asked to write a column sums kernel that uses fully coalesced memory access patterns. To begin you will observe the performance of a row sums kernel that makes uncoalesced memory accesses."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Row Sums"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Imports**"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import cuda"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Data Creation**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this cell we create an input matrix, as well as a vector for storing the solution, and transfer each of them to the device. We also define the grid and block dimensions to be used when we launch the kernel below. We set an arbitrary row of data to some arbitrary value to facilitate checking for correctness below."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"n = 16384 # matrix side size\n",
"threads_per_block = 256\n",
"blocks = int(n / threads_per_block)\n",
"\n",
"# Input Matrix\n",
"a = np.ones(n*n).reshape(n, n).astype(np.float32)\n",
"# Here we set an arbitrary row to an arbitrary value to facilitate a check for correctness below.\n",
"a[3] = 9\n",
"\n",
"# Output vector\n",
"sums = np.zeros(n).astype(np.float32)\n",
"\n",
"d_a = cuda.to_device(a)\n",
"d_sums = cuda.to_device(sums)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"** Kernel Definition**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`row_sums` will use each thread to iterate over a row of data, summing it, and then store its row sum in `sums`."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"@cuda.jit\n",
"def row_sums(a, sums, n):\n",
" idx = cuda.grid(1)\n",
" sum = 0.0\n",
" \n",
" for i in range(n):\n",
" # Each thread will sum a row of `a`\n",
" sum += a[idx][i]\n",
" \n",
" sums[idx] = sum"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Row Sums Performance**"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"11.6 ms ± 215 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%timeit row_sums[blocks, threads_per_block](d_a, d_sums, n); cuda.synchronize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Check for Correctness**"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"result = d_sums.copy_to_host()\n",
"truth = a.sum(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.array_equal(truth, result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Column Sums"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Imports**"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import cuda"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Data Creation**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this cell we create an input matrix, as well as a vector for storing the solution, and transfer each of them to the device. We also define the grid and block dimensions to be used when we launch the kernel below. We set an arbitrary column of data to some arbitrary value to facilitate checking for correctness below."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"n = 16384 # matrix side size\n",
"threads_per_block = 256\n",
"blocks = int(n / threads_per_block)\n",
"\n",
"a = np.ones(n*n).reshape(n, n).astype(np.float32)\n",
"# Here we set an arbitrary column to an arbitrary value to facilitate a check for correctness below.\n",
"a[:, 3] = 9\n",
"sums = np.zeros(n).astype(np.float32)\n",
"\n",
"d_a = cuda.to_device(a)\n",
"d_sums = cuda.to_device(sums)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"** Kernel Definition**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`col_sums` will use each thread to iterate over a column of data, summing it, and then store its column sum in `sums`. Complete the kernel definition to accomplish this. If you get stuck, feel free to refer to [the solution](../edit/solutions/col_sums_solution.py)."
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"@cuda.jit\n",
"def col_sums(a, sums, ds):\n",
" idx = cuda.grid(1)\n",
" stride = cuda.gridsize(1)\n",
" sum = 0\n",
" for i in range(stride):\n",
" sum += a[i][idx]\n",
" \n",
" sums[idx] = sum"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Check Performance**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assuming you have written `col_sums` to use coalesced access patterns, you should see a significant (almost 2x) speed up compared to the uncoalesced `row_sums` you ran above:"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7.88 ms ± 3.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit col_sums[blocks, threads_per_block](d_a, d_sums, n); cuda.synchronize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Check Correctness**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Confirm your kernel is working as expected."
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"col_sums[blocks, threads_per_block](d_a, d_sums, n)\n",
"cuda.synchronize()\n",
"result = d_sums.copy_to_host()\n",
"truth = a.sum(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.array_equal(truth, result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2 and 3 Dimensional Blocks and Grids"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both grids and blocks can be configured to contain a 2 or 3 dimensional collection of blocks or threads, respectively. This is done mostly as a matter of convenience for programmers who often work with 2 or 3 dimensional datasets. Here is a very trivial example to highlight the syntax. You may need to read *both* the kernel definition and its launch before the concept makes sense."
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import cuda"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"A = np.zeros((4,4)) # A 4x4 Matrix of 0's\n",
"d_A = cuda.to_device(A)\n",
"\n",
"# Here we create a 2D grid with 4 blocks in a 2x2 structure, each with 4 threads in a 2x2 structure\n",
"# by using a Python tuple to signify grid and block dimensions.\n",
"blocks = (2, 2)\n",
"threads_per_block = (2, 2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This kernel will take an input matrix of 0s and write to each of its elements, its (x,y) coordinates within the grid in the format of `X.Y`:"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"@cuda.jit\n",
"def get_2D_indices(A):\n",
" # By passing `2`, we get the thread's unique x and y coordinates in the 2D grid\n",
" x, y = cuda.grid(2)\n",
" # The above is equivalent to the following 2 lines of code:\n",
" # x = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x\n",
" # y = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y\n",
" \n",
" # Write the x index followed by a decimal and the y index.\n",
" A[x][y] = x + y / 10"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"get_2D_indices[blocks, threads_per_block](d_A)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0.1, 0.2, 0.3],\n",
" [1. , 1.1, 1.2, 1.3],\n",
" [2. , 2.1, 2.2, 2.3],\n",
" [3. , 3.1, 3.2, 3.3]])"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = d_A.copy_to_host()\n",
"result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise: Coalesced 2-Dimensional Matrix Add"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import cuda"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data Creation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this cell we define 2048x2048 elmement input matrices `a` and `b`, as well as a 2048x2048 0-initialized output matrix. We copy these matrices to the device.\n",
"\n",
"We also define the 2-dimensional block and grid dimensions to be used below. Note that we are creating a grid with the same number of total threads as there are input and output elements, such that each thread in the grid will calculate the sum for a single element in the output matrix."
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"n = 2048*2048 # 4M\n",
"\n",
"# 2D blocks\n",
"threads_per_block = (32, 32)\n",
"# 2D grid\n",
"blocks = (64, 64)\n",
"\n",
"# 2048x2048 input matrices\n",
"a = np.arange(n).reshape(2048,2048).astype(np.float32)\n",
"b = a.copy().astype(np.float32)\n",
"\n",
"# 2048x2048 0-initialized output matrix\n",
"out = np.zeros_like(a).astype(np.float32)\n",
"\n",
"d_a = cuda.to_device(a)\n",
"d_b = cuda.to_device(b)\n",
"d_out = cuda.to_device(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2D Matrix Add"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Your job is to complete the TODOs in `matrix_add` to correctly sum `a` and `b` into `out`. As a challenge to your understanding of coalesced access patterns, `matrix_add` will accept a `coalesced` boolean indicating whether the access patterns should be coalesced or not. Both modes (coalesced and uncoalesced) should produce correct results, however, you should observe significant speedups below when running with `coalesced` set to `True`.\n",
"\n",
"If you get stuck, feel free to check out [the solution](../edit/solutions/matrix_add_solution.py)."
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"@cuda.jit\n",
"def matrix_add(a, b, out, coalesced):\n",
" x, y = cuda.grid(2)\n",
" \n",
" if coalesced == True:\n",
" out[y][x] = a[y][x] + b[y][x]\n",
" else:\n",
" out[x][y] = a[x][y] + b[x][y]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check Performance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run both cells below to launch `matrix_add` with both the coalesced and uncoalesced access patterns you wrote into it, and observe the performance difference. Additional cells have been provided to confirm the correctness of your kernel."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Coalesced**"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"203 µs ± 13.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"%timeit matrix_add[blocks, threads_per_block](d_a, d_b, d_out, True); cuda.synchronize"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"result = d_out.copy_to_host()\n",
"truth = a+b"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.array_equal(result, truth)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Uncoalesced**"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"586 µs ± 1.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"%timeit matrix_add[blocks, threads_per_block](d_a, d_b, d_out, False); cuda.synchronize"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"result = d_out.copy_to_host()\n",
"truth = a+b"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.array_equal(result, truth)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Shared Memory"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So far we have been differentiating between host and device memory, as if device memory were a single kind of memory. But in fact, CUDA has an even more fine-grained [memory hierarchy](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-hierarchy). The device memory we have been utilizing thus far is called **global memory** which is available to any thread or block on the device, can persist for the lifetime of the application, and is a relatively large memory space.\n",
"\n",
"We will now discuss how to utilize a region of on-chip device memory called **shared memory**. Shared memory is a programmer defined cache of limited size that [depends on the GPU](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities) being used and is **shared** between all threads in a block. It is a scarce resource, cannot be accessed by threads outside of the block where it was allocated, and does not persist after a kernel finishes executing. Shared memory however has a much higher bandwidth than global memory and can be used to great effect in many kernels, especially to optimize performance.\n",
"\n",
"Here are a few common use cases for shared memory:\n",
"\n",
" * Caching memory read from global memory that will need to be read multiple times within a block.\n",
" * Buffering output from threads so it can be coalesced before writing it back to global memory.\n",
" * Staging data for scatter/gather operations within a block."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Shared Memory Syntax"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Numba provides [functions](https://numba.pydata.org/numba-doc/dev/cuda/memory.html#shared-memory-and-thread-synchronization) for allocating shared memory as well as for synchronizing between threads in a block, which is often necessary after parallel threads read from or write to shared memory.\n",
"\n",
"When declaring shared memory, you provide the shape of the shared array, as well as its type, using a [Numba type](https://numba.pydata.org/numba-doc/dev/reference/types.html#numba-types). **The shape of the array must be a constant value**; therefore, you cannot use arguments passed into the function, provided variables like `numba.cuda.blockDim.x`, or the calculated values of `cuda.griddim`. Here is a contrived example to demonstrate the syntax, with comments pointing out the movement from host memory to global device memory, to shared memory, back to global device memory, and finally back to host memory:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Imports**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use `numba.types` to define the types of values in shared memory."
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import types, cuda"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Swap Elements Using Shared Memory**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following kernel takes an input vector, where each thread will first write one element of the vector to shared memory, and then, after syncing such that all elements have been written to shared memory, will write one element out of shared memory into the swapped output vector.\n",
"\n",
"Worth noting is that each thread will be writing a swapped value from shared memory that was written into shared memory by another thread."
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"@cuda.jit\n",
"def swap_with_shared(vector, swapped):\n",
"    # Reverse each block's 4 elements of `vector` into `swapped`, staging them\n",
"    # through shared memory. Assumes the kernel is launched with 4 threads per block.\n",
"\n",
"    # Allocate a 4 element vector containing int32 values in shared memory.\n",
"    # The shape must be a compile-time constant, hence the literal 4.\n",
"    temp = cuda.shared.array(4, dtype=types.int32)\n",
"\n",
"    # The global index selects the element of `vector`/`swapped`; the thread's\n",
"    # index within the block selects the slot in the per-block shared array,\n",
"    # so the shared array is never indexed past its 4 elements.\n",
"    idx = cuda.grid(1)\n",
"    tid = cuda.threadIdx.x\n",
"\n",
"    # Move an element from global memory into shared memory\n",
"    temp[tid] = vector[idx]\n",
"\n",
"    # cuda.syncthreads will force all threads in the block to synchronize here, which is necessary because...\n",
"    cuda.syncthreads()\n",
"    #...the following operation is reading an element written to shared memory by another thread.\n",
"\n",
"    # Move an element from shared memory back into global memory,\n",
"    # reading the mirrored slot (3 is the last index of the 4-element array).\n",
"    swapped[idx] = temp[3 - tid] # swap elements"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Data Creation**"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"vector = np.arange(4).astype(np.int32)\n",
"swapped = np.zeros_like(vector)\n",
"\n",
"# Move host memory to device (global) memory\n",
"d_vector = cuda.to_device(vector)\n",
"d_swapped = cuda.to_device(swapped)"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 1, 2, 3], dtype=int32)"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vector"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Run Kernel**"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"swap_with_shared[1, 4](d_vector, d_swapped)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Check Results**"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([3, 2, 1, 0], dtype=int32)"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Move device (global) memory back to the host\n",
"result = d_swapped.copy_to_host()\n",
"result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Presentation: Shared Memory for Memory Coalescing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Execute the following cell to load the slides, then click on \"Start Slide Show\" to make them full screen."
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"800\"\n",
" height=\"450\"\n",
" src=\"https://view.officeapps.live.com/op/view.aspx?src=https://developer.download.nvidia.com/training/courses/C-AC-02-V1/shared_coalescing.pptx\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7f9086acc438>"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import IFrame\n",
"IFrame('https://view.officeapps.live.com/op/view.aspx?src=https://developer.download.nvidia.com/training/courses/C-AC-02-V1/shared_coalescing.pptx', 800, 450)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise: Use Shared Memory for Coalesced Reads and Writes With Matrix Transpose"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this exercise you will implement what was just demonstrated in the presentation by writing a matrix transpose kernel which, using shared memory, makes coalesced reads from the input matrix and coalesced writes to the output matrix in global memory."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Coalesced Reads, Uncoalesced Writes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reference, and for performance comparison, here is a naive matrix transpose kernel that makes coalesced reads from the input, but uncoalesced writes to the output."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Imports**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from numba import cuda\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Data Creation**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we create a 4096x4096 input matrix `a` as well as a 4096x4096 output matrix `transposed`, and copy them to the device.\n",
"\n",
"We also define a 2-dimensional grid with 2-dimensional blocks to be used below. Note that we have created a grid with a total number of threads equal to the number of elements in the input matrix."
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Side length of the square input matrix. The grid below assumes it is a\n",
"# multiple of the 32x32 block dimension.\n",
"size = 4096\n",
"n = size * size # 16M\n",
"\n",
"# 2D blocks\n",
"threads_per_block = (32, 32)\n",
"assert size % threads_per_block[0] == 0 and size % threads_per_block[1] == 0\n",
"# 2D grid, derived from the matrix size so there is exactly one thread per element\n",
"# (4096 / 32 = 128 blocks along each dimension).\n",
"blocks = (size // threads_per_block[0], size // threads_per_block[1])\n",
"\n",
"# 4096x4096 input and output matrices\n",
"a = np.arange(n).reshape((size, size)).astype(np.float32)\n",
"# np.zeros_like already keeps a's float32 dtype, so no extra .astype copy is needed.\n",
"transposed = np.zeros_like(a)\n",
"\n",
"d_a = cuda.to_device(a)\n",
"d_transposed = cuda.to_device(transposed)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Naive Matrix Transpose Kernel**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This kernel correctly transposes `a`, writing the transposition to `transposed`. It makes reads from `a` in a coalesced fashion; however, its writes to `transposed` are uncoalesced."
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"@cuda.jit\n",
"def transpose(a, transposed):\n",
"    # Naive transpose: each thread copies one element, a[y, x] -> transposed[x, y].\n",
"    # Consecutive threadIdx.x values map to consecutive columns of `a` (coalesced\n",
"    # reads) but to consecutive rows of `transposed` (strided, uncoalesced writes).\n",
"    # No bounds check: assumes the grid covers `a` exactly, as configured above.\n",
"    x, y = cuda.grid(2)\n",
"\n",
"    # Index with a single tuple rather than chained `[x][y]`, which would first\n",
"    # build a row view before indexing into it.\n",
"    transposed[x, y] = a[y, x]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Check Performance**"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.59 ms ± 25.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit transpose[blocks, threads_per_block](d_a, d_transposed); cuda.synchronize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Check Correctness**"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"result = d_transposed.copy_to_host()\n",
"expected = a.T"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.array_equal(result, expected)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Refactor for Coalesced Reads and Writes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Your job will be to refactor the `transpose` kernel to use shared memory and make both reads from and writes to global memory in a coalesced fashion."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Imports**"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import cuda, types as numba_types"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Data Creation**"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Side length of the square input matrix. The grid below assumes it is a\n",
"# multiple of the 32x32 block dimension.\n",
"size = 4096\n",
"n = size * size # 16M\n",
"\n",
"# 2D blocks\n",
"threads_per_block = (32, 32)\n",
"assert size % threads_per_block[0] == 0 and size % threads_per_block[1] == 0\n",
"# 2D grid, derived from the matrix size so there is exactly one thread per element\n",
"# (4096 / 32 = 128 blocks along each dimension).\n",
"blocks = (size // threads_per_block[0], size // threads_per_block[1])\n",
"\n",
"# 4096x4096 input and output matrices\n",
"a = np.arange(n).reshape((size, size)).astype(np.float32)\n",
"# np.zeros_like already keeps a's float32 dtype, so no extra .astype copy is needed.\n",
"transposed = np.zeros_like(a)\n",
"\n",
"d_a = cuda.to_device(a)\n",
"d_transposed = cuda.to_device(transposed)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Write a Transpose Kernel that Uses Shared Memory**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Complete the TODOs inside the `tile_transpose` kernel definition.\n",
"\n",
"If you get stuck, feel free to check out [the solution](../edit/solutions/tile_transpose_solution.py)."
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"@cuda.jit\n",
"def tile_transpose(a, transposed):\n",
"    # `tile_transpose` assumes it is launched with a 32x32 block dimension,\n",
"    # and that the dimensions of `a` are multiples of 32.\n",
"\n",
"    # 1) Create 32x32 shared memory array.\n",
"    # Its dtype must match the float32 input matrix: an int32 tile would\n",
"    # silently truncate any fractional values on the way through.\n",
"    # (`numba_types` is the alias imported at the start of this exercise.)\n",
"    temp = cuda.shared.array((32, 32), dtype=numba_types.float32)\n",
"\n",
"    # Compute offsets into global input array. Recall for coalesced access we want to map threadIdx.x increments to\n",
"    # the fastest changing index in the data, i.e. the column in our array.\n",
"    # Note: `a_col` and `a_row` are already correct.\n",
"    a_col = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x\n",
"    a_row = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y\n",
"\n",
"    # 2) Make coalesced read from global memory (using grid indices)\n",
"    # into shared memory array (using thread indices).\n",
"    temp[cuda.threadIdx.y, cuda.threadIdx.x] = a[a_row, a_col]\n",
"\n",
"    # 3) Wait for all threads in the block to finish updating shared memory.\n",
"    cuda.syncthreads()\n",
"\n",
"    # 4) Calculate transposed location for the shared memory array tile\n",
"    # to be written back to global memory. Note that blockIdx.y*blockDim.y\n",
"    # and blockIdx.x*blockDim.x are swapped (because we want to write to the\n",
"    # transpose locations), but we want to keep access coalesced, so match up the\n",
"    # threadIdx.x to the fastest changing index, i.e. the column.\n",
"    # Note: `t_col` and `t_row` are already correct.\n",
"    t_col = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.x\n",
"    t_row = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.y\n",
"\n",
"    # 5) Write from shared memory (using thread indices)\n",
"    # back to global memory (using grid indices)\n",
"    # transposing each element within the shared memory array.\n",
"    transposed[t_row, t_col] = temp[cuda.threadIdx.x, cuda.threadIdx.y]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Check Performance**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check the performance of your refactored transpose kernel. You should see a speedup compared to the baseline transpose performance above."
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.09 ms ± 60.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit tile_transpose[blocks, threads_per_block](d_a, d_transposed); cuda.synchronize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Check Correctness**"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"result = d_transposed.copy_to_host()\n",
"expected = a.T"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 94,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.array_equal(result, expected)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Why Such a Small Improvement?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"While this is a significant speedup for only a few lines of code, you might think that the performance improvement is not as stark as you expected based on earlier performance improvements from using coalesced access patterns. There are 2 main reasons for this:\n",
"\n",
"1. The naive transpose kernel was already making coalesced reads, so your refactored version only optimized half of the global memory accesses made throughout the execution of the kernel.\n",
"2. Your code as written suffers from something called shared memory bank conflicts, a topic to which we will now turn our attention."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Presentation: Memory Bank Conflicts"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Execute the following cell to load the slides, then click on \"Start Slide Show\" to make them full screen."
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"800\"\n",
" height=\"450\"\n",
" src=\"https://view.officeapps.live.com/op/view.aspx?src=https://developer.download.nvidia.com/training/courses/C-AC-02-V1/bank_conflicts.pptx\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7f908e3479e8>"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import IFrame\n",
"IFrame('https://view.officeapps.live.com/op/view.aspx?src=https://developer.download.nvidia.com/training/courses/C-AC-02-V1/bank_conflicts.pptx', 800, 450)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Assessment: Resolve Memory Bank Conflicts"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a final exercise, and to get credit towards a certificate in the course for this final section of the workshop, you will refactor the transpose kernel utilizing shared memory to be shared memory bank conflict free."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import cuda, types as numba_types"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data Creation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Side length of the square input matrix. The grid below assumes it is a\n",
"# multiple of the 32x32 block dimension.\n",
"size = 4096\n",
"n = size * size # 16M\n",
"threads_per_block = (32, 32)\n",
"assert size % threads_per_block[0] == 0 and size % threads_per_block[1] == 0\n",
"# One thread per matrix element: 4096 / 32 = 128 blocks along each dimension.\n",
"blocks = (size // threads_per_block[0], size // threads_per_block[1])\n",
"\n",
"a = np.arange(n).reshape((size, size)).astype(np.float32)\n",
"# np.zeros_like already keeps a's float32 dtype, so no extra .astype copy is needed.\n",
"transposed = np.zeros_like(a)\n",
"\n",
"d_a = cuda.to_device(a)\n",
"d_transposed = cuda.to_device(transposed)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Make the Kernel Bank Conflict Free"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `tile_transpose_conflict_free` kernel is a working matrix transpose kernel which utilizes shared memory so that both reads from and writes to global memory are coalesced. Your job is to refactor the kernel so that it does not suffer from memory bank conflicts.\n",
"\n",
"**Note:** Because this final exercise counts towards certification in the course, a solution will not be provided."
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"@cuda.jit\n",
"def tile_transpose_conflict_free(a, transposed):\n",
"    # `tile_transpose_conflict_free` assumes it is launched with a 32x32 block\n",
"    # dimension, and that the dimensions of `a` are multiples of 32.\n",
"\n",
"    # 1) Create a 32x32 tile in shared memory, padded with one extra column (32x33).\n",
"    # Shared memory is split into 32 banks of 4-byte words; with a row stride of\n",
"    # 33 float32 words, the elements of one tile column fall in 32 different banks,\n",
"    # so the column-wise read in step 5 is free of bank conflicts.\n",
"    tile = cuda.shared.array((32, 33), numba_types.float32)\n",
"\n",
"    # Compute offsets into global input array.\n",
"    x = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x\n",
"    y = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y\n",
"\n",
"    # 2) Make coalesced read from global memory into shared memory array.\n",
"    # Note the use of local thread indices for the shared memory write,\n",
"    # and global offsets for global memory read.\n",
"    tile[cuda.threadIdx.y, cuda.threadIdx.x] = a[y, x]\n",
"\n",
"    # 3) Wait for all threads in the block to finish updating shared memory.\n",
"    cuda.syncthreads()\n",
"\n",
"    # 4) Calculate transposed location for the shared memory array tile\n",
"    # to be written back to global memory. The block offsets are swapped, while\n",
"    # threadIdx.x still maps to the column so the global write stays coalesced.\n",
"    t_x = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.x\n",
"    t_y = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.y\n",
"\n",
"    # 5) Write back to global memory,\n",
"    # transposing each element within the shared memory array\n",
"    # (threads with consecutive threadIdx.x read down a column of `tile`).\n",
"    transposed[t_y, t_x] = tile[cuda.threadIdx.x, cuda.threadIdx.y]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check Performance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assuming you have correctly resolved the bank conflicts, this kernel should run significantly faster than both the naive transpose kernel and the shared memory (with bank conflicts) transpose kernel. In order to pass the assessment, your kernel will need to run in less than 840 µs on average.\n",
"\n",
"The first value printed by running the following cell will give you the average run time of your kernel."
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"804 µs ± 2.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"%timeit tile_transpose_conflict_free[blocks, threads_per_block](d_a, d_transposed); cuda.synchronize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check Correctness"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to pass the assessment, your kernel also needs to work correctly. Run the following 2 cells to confirm this is true."
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"result = d_transposed.copy_to_host()\n",
"expected = a.T"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.array_equal(result, expected)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run the Assessment"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you have completed the refactor, observed its run time to be less than 840 µs, and confirmed that it runs correctly, execute the following cells to run the assessment against your kernel definition."
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from assessment import assess"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Your function took 808.28 µs to run.\n",
"\n",
"Your function runs fast enough (less than 840 µs): True\n",
"\n",
"Your function returns the correct results: True\n",
"\n",
"Congratulations, you passed! See the instructions below for how to get credit for your work to count toward a certificate in the course.\n"
]
}
],
"source": [
"assess(tile_transpose_conflict_free)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get Credit for Your Work"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After successfully passing the assessment above, revisit the webpage where you launched this interactive environment and click on the **\"ASSESS TASK\"** button as shown in the screenshot below. Doing so will give you credit for this part of the workshop that counts towards earning a **certificate of competency** for the entire course."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that you have completed this session you are able to:\n",
"\n",
"* Write CUDA kernels that benefit from coalesced memory access patterns.\n",
"* Work with multi-dimensional grids and thread blocks.\n",
"* Use shared memory to coordinate threads within a block.\n",
"* Use shared memory to facilitate coalesced memory access patterns.\n",
"* Resolve shared memory bank conflicts."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download Content"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To download the contents of this notebook, execute the following cell and then click the download link below. Note: If you run this notebook on a local Jupyter server, you can expect some of the file path links in the notebook to be broken as they are shaped to our own platform. You can still navigate to the files through the Jupyter file navigator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"!tar -zcvf section3.tar.gz ."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Download files from this section.](files/section3.tar.gz)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|