Conversation

malfet

I.e. it feels reasonable to always call `at::cuda::gemm` rather than `at::cuda::bgemm` when num_batches == 1.
After the change, benchmark results for torch built with CUDA-12, collected with the [following perf script](https://gist.github.com/malfet/6a17156d7f5663b8b12054a1beff3fe1) on an A100, are as follows:

|      Shape     |  bmm_time |  mm_time  | slow down (%) |
| -------------- | --------- | --------- | ------------- |
|    1x1x4096    |   14.18   |   14.31   |     -0.89     |
|    1x1x8192    |   14.37   |   14.37   |     -0.05     |
|   1x1x16384    |   14.03   |   14.12   |     -0.68     |
|   1x1x32768    |   14.19   |   14.24   |     -0.35     |
|   1x1x65536    |   14.85   |   14.52   |     2.30      |
|   1x1x131072   |   14.03   |   14.07   |     -0.33     |
|  128x128x128   |   11.34   |   11.06   |     2.56      |
|  256x256x256   |   14.85   |   14.40   |     3.15      |
|  512x512x512   |   27.22   |   27.22   |     -0.01     |
| 1024x1024x1024 |  129.66   |  129.50   |     0.12      |
| 2048x2048x2048 |  972.18   |  973.24   |     -0.11     |
|  129x127x129   |   11.21   |   11.25   |     -0.39     |
|  257x255x257   |   14.50   |   14.43   |     0.44      |
|  513x511x513   |   29.01   |   29.01   |     0.01      |
| 1025x1023x1025 |  137.65   |  137.64   |     0.01      |
| 2049x2047x2049 |  982.58   |  982.65   |     -0.01     |
|  4097x3x4097   |   86.65   |   86.64   |     0.01      |
|  8193x3x8193   |  384.02   |  383.96   |     0.02      |
| 16385x3x16385  |  1106.73  |  1107.32  |     -0.05     |
| 32769x3x32769  |  4739.49  |  4739.48  |     0.00      |
| 65537x3x65537  | 17377.78  | 17378.74  |     -0.01     |
|  4097x5x4097   |   87.09   |   87.12   |     -0.03     |
|  8193x5x8193   |  301.38   |  301.36   |     0.01      |
| 16385x5x16385  |  1107.38  |  1108.04  |     -0.06     |
| 32769x5x32769  |  4743.73  |  4744.07  |     -0.01     |
| 65537x5x65537  | 17392.32  | 17395.42  |     -0.02     |
|  4097x7x4097   |   87.17   |   87.19   |     -0.02     |
|  8193x7x8193   |  301.94   |  302.00   |     -0.02     |
| 16385x7x16385  |  1107.17  |  1106.79  |     0.03      |
| 32769x7x32769  |  4747.15  |  4747.13  |     0.00      |
| 65537x7x65537  | 17403.85  | 17405.02  |     -0.01     |

Fixes the perf problem reported in #114911.
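
For intuition, here is a minimal sketch (not the linked gist) of the equivalence the change relies on, plus a rough timing comparison; the shape and iteration count below are illustrative assumptions, not the values used in the PR benchmark:

```python
import torch

def time_op(fn, iters=100):
    # Rough CUDA timing via events; assumes a CUDA device is available.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()  # warm-up
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters * 1e3  # microseconds per call

if torch.cuda.is_available():
    a = torch.randn(1, 1024, 1024, device="cuda")
    b = torch.randn(1, 1024, 1024, device="cuda")
    # A batch-of-1 bmm is just an mm on the squeezed operands.
    diff = (torch.bmm(a, b)[0] - torch.mm(a[0], b[0])).abs().max()
    print(f"max abs diff: {diff.item():.2e}")
    print(f"bmm: {time_op(lambda: torch.bmm(a, b)):.2f} us")
    print(f" mm: {time_op(lambda: torch.mm(a[0], b[0])):.2f} us")
```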

malfet requested a review from ngimel December 1, 2023 23:12

malfet added the `release notes: cuda` and `topic: performance` labels Dec 1, 2023
ptrblck requested a review from eqy December 1, 2023 23:14


Einsum failure looks like it could be real; I'm guessing it could be because `self` here might not have a batch dimension? (e.g., `torch.baddbmm(torch.randn(32, 32), torch.randn(1, 32, 32), torch.randn(1, 32, 32))`)

EDIT: Not sure if that's the issue as squeeze should be a no-op in that case anyway...
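
For reference, a small sketch of the broadcasting case from the example above; it runs on CPU purely to show the shape semantics, and is not part of the failing test:

```python
import torch

# `self` is 2-D while the batch operands are 3-D, so it broadcasts
# across the batch dimension of batch1 @ batch2.
self_2d = torch.randn(32, 32)
b1 = torch.randn(1, 32, 32)
b2 = torch.randn(1, 32, 32)

out = torch.baddbmm(self_2d, b1, b2)
# Same computation with the batch dimension made explicit:
ref = torch.baddbmm(self_2d.unsqueeze(0), b1, b2)
print(out.shape, torch.allclose(out, ref))  # torch.Size([1, 32, 32]) True
```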

@malfet

Yes, the test failures look related, but I cannot reproduce them locally.
And `torch.randn(32, 32).squeeze(0)` is a no-op, but `torch.rand(1, 32).squeeze(0)` is not, so...
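
A quick sketch of the distinction, with shapes taken from the examples above:

```python
# squeeze(0) only removes the leading dimension when its size is 1.
import torch

print(torch.randn(32, 32).squeeze(0).shape)  # torch.Size([32, 32]) -- unchanged
print(torch.rand(1, 32).squeeze(0).shape)    # torch.Size([32])     -- dim removed
```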

malfet force-pushed the malfet/cuda-fall-back-to-addm branch from 83c4c05 to 42ada62 on December 5, 2023 04:05
malfet requested a review from mruberry as a code owner December 5, 2023 04:05
@malfet

> Einsum failure looks like it could be real

After some debugging, it is indeed related to this PR, but in a weird way: bfloat16 is not supported by `bmm` on sm52, but `mm` works fine, so backward now works on older GPUs...

Tweaked the OpInfo limits a bit.
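
A hedged probe of that behavior (shapes are illustrative; on an sm52-class GPU the expectation described above is that `bmm` raises while `mm` succeeds):

```python
import torch

if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability()
    for name, fn in [
        ("mm",  lambda: torch.mm(torch.randn(8, 8, device="cuda", dtype=torch.bfloat16),
                                 torch.randn(8, 8, device="cuda", dtype=torch.bfloat16))),
        ("bmm", lambda: torch.bmm(torch.randn(2, 8, 8, device="cuda", dtype=torch.bfloat16),
                                  torch.randn(2, 8, 8, device="cuda", dtype=torch.bfloat16))),
    ]:
        try:
            fn()
            print(f"sm{cap[0]}{cap[1]}: bfloat16 {name} OK")
        except RuntimeError as e:
            print(f"sm{cap[0]}{cap[1]}: bfloat16 {name} unsupported: {e}")
```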

eqy approved these changes Dec 6, 2023
@malfet

@pytorchbot merge

pytorch-bot added the `ciflow/trunk` label Dec 8, 2023
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

malfet deleted the malfet/cuda-fall-back-to-addm branch December 18, 2023 19:03
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
atalman pushed a commit to atalman/pytorch that referenced this pull request Dec 28, 2023
atalman added a commit that referenced this pull request Jan 2, 2024