Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2325
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 1 Unrelated Failure
As of commit 186708f with merge base f0f1f6c:
NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
isinstance(self.w1, torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor) and
isinstance(self.w1.original_weight_tensor._layout, torchao.dtypes.floatx.float8_layout.Float8Layout)
):
    final_out = fp8_dq_moe_op(x, self.w1, self.w2, self.w3, expert_indices, scores)
```
Is it possible to call this op without modifying the source model?
Is there a grouped_mm for bfloat16 that we can overwrite and dispatch to scaled_grouped_mm?
Yes, there is `_grouped_mm` in PyTorch core that does that.
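A minimal sketch of calling that op directly, assuming a recent PyTorch build where the private `torch._grouped_mm` op is available with roughly the signature exercised in `test_matmul_cuda.py` further down this thread (treat the exact signature as an assumption):

```python
import torch

# bf16 grouped GEMM: tokens for all experts stacked along dim 0 of `a`,
# one (k, n) weight per expert in `b`, and cumulative group ends in `offs`.
m, k, n, n_groups = 16, 64, 32, 4
a = torch.randn(m * n_groups, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(n_groups, k, n, device="cuda", dtype=torch.bfloat16)
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)

# A quantized-weight override could intercept this call and dispatch to the
# scaled (fp8) variant instead, without touching the source model.
out = torch._grouped_mm(a, b, offs=offs)  # -> (m * n_groups, n), bf16
```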
HDCharles Jun 11, 2025 (edited)
That's a better integration point, but I'm not sure I'll be able to complete it before I have to head out on leave.
Also, I'd probably make that a separate PR instead of combining everything into one, since it would be a significant change to the base MoE integration.
The PR that will hopefully remove the need for padding groups is here: pytorch/pytorch#155466.
torchao//moe_quant/kernels.py Outdated
```python
alignment = 16
if _torchtitan_available:
    num_ranks = 1
    padded_indices, m_offsets = torchtitan_pad(num_tokens_per_expert, alignment, num_ranks)
```
Heads up: soon we won't need padding, once #155466 lands.
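For context on what the padding path in the snippet above is doing until then, here is a minimal sketch with a hypothetical helper (not the `torchtitan_pad` API): each expert's token count is rounded up to the alignment, so every group boundary handed to the grouped GEMM is a multiple of 16.

```python
import torch

def pad_group_offsets(num_tokens_per_expert: torch.Tensor, alignment: int = 16) -> torch.Tensor:
    # Round each group size up to `alignment`, then build cumulative group ends.
    padded_sizes = ((num_tokens_per_expert + alignment - 1) // alignment) * alignment
    return torch.cumsum(padded_sizes, dim=0).to(torch.int32)

num_tokens_per_expert = torch.tensor([3, 17, 0, 9])
print(pad_group_offsets(num_tokens_per_expert).tolist())  # [16, 48, 48, 64]
```

Once #155466 lands, the raw cumulative token counts can be passed as offsets directly.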
```python
input_fp8[valid_values] = q_input_data[token_shuffle]
input_scale[valid_values] = q_input_scale[token_shuffle] if q_input_scale.numel() > 1 else q_input_scale
if use_fbgemm_kernel:
```
We have FBGEMM-like kernels available via autotuning in torch.compile, thanks to #155138; do you think we still need a separate FBGEMM path?
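For reference, a sketch of what leaning on the compiled/autotuned path (instead of a separate FBGEMM call) could look like; the private op name `torch._scaled_grouped_mm` and its argument order are assumptions based on recent PyTorch nightlies and may differ across versions:

```python
import torch

@torch.compile(mode="max-autotune")
def fp8_rowwise_grouped_mm(a_fp8, b_fp8, a_scale, b_scale, offs):
    # a_fp8: (total_tokens, k) float8_e4m3fn, b_fp8: (n_groups, k, n) float8_e4m3fn,
    # a_scale / b_scale: float32 row-wise scales, offs: int32 cumulative group ends.
    # With max-autotune, Inductor can pick a Triton grouped-GEMM template here
    # instead of falling back to an eager kernel.
    return torch._scaled_grouped_mm(
        a_fp8, b_fp8, a_scale, b_scale, offs=offs, out_dtype=torch.bfloat16
    )
```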
Summary:
Extending the torchao MoE support to have more performant kernels. This PR supports both scaled_grouped_mm and FBGEMM's grouped_gemm_fp8_rowwise, though grouped_gemm_fp8_rowwise seems to be a bit buggy (need to make a clear repro).

Todo: run benchmarks, debug the FBGEMM kernel, unit tests.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
PR pytorch/pytorch#155466, which makes it possible to avoid padding, is merged. Here is a quick diff to remove the padding (note that it also disables FBGEMM altogether, so I'm not completely sure the scale tensor shape adjustments I made here are correct, but in any case this, together with the latest PyTorch, will make all the tests in the linked file pass).
@alexsamardzic we would still need padding for backward, where K could possibly become 0?
This PR is not concerned with backward, but I would say @danielvegamyhre is touching on it: #2405. In any case, you have a point; here is a diff for the test:

```diff
diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index 4e64c807425..96667a79440 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -354,15 +354,15 @@ class TestMatmulCuda(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
-    @parametrize("strided", [False, True])
-    @parametrize("a_row_major", [False, True])
-    @parametrize("b_row_major", [False, True])
-    @parametrize("use_torch_compile", [False, True])
+    @parametrize("strided", [True])
+    @parametrize("a_row_major", [True])
+    @parametrize("b_row_major", [True])
+    @parametrize("use_torch_compile", [True, False])
     def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, use_torch_compile):
         device = "cuda"
         dtype = torch.bfloat16
         s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 64, 4
+        m, n, k, n_groups = 3, 32, 64, 5
         if a_row_major:
             a = torch.randn(m * n_groups, k * (1 + s_int), device=device, dtype=dtype)[:, :k]
         else:
@@ -388,6 +388,7 @@ class TestMatmulCuda(TestCase):
             a.grad = None
             b.grad = None
             offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+            offs = torch.tensor([0, 1, 6, 6, 15], device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]
```

If the offsets are changed, say to …

Edit: see here.
Here is a slightly changed diff: 096_fuse_moeb-diff.txt. To be applied after the PR is rebased on latest main.

Some end-to-end performance numbers for the Mixtral model, for the current version of the PR.

First, this is to be applied to enforce auto-tuning for all cases:

```diff
diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py
index 11a53043..10da20f7 100644
--- a/torchao/_models/mixtral-moe/generate.py
+++ b/torchao/_models/mixtral-moe/generate.py
@@ -337,10 +337,10 @@ def main(
     if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant):
         decode_one_token = torch.compile(
-            decode_one_token, mode="reduce-overhead", fullgraph=True
+            decode_one_token, mode="max-autotune"
         )
     else:
-        decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
+        decode_one_token = torch.compile(decode_one_token, mode="max-autotune")
     if args.compile_prefill:
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
```

For each run below, …

First, a variant that goes through this branch of the model forward function:

```
$ python generate.py --compile --moe_quant fp8dq
Average tokens/sec: 10.23
Memory used: 51.21 GB
model size: 48.37
```

Then a variant that goes through this branch:

```
$ python generate.py --compile --moe_quant fp8dq-base
Average tokens/sec: 56.34
Memory used: 59.14 GB
model size: 48.37
```

Finally, a variant that will utilize the auto-tuned …:

```
$ python generate.py --compile --moe_quant fp8dq-base
Average tokens/sec: 101.24
Memory used: 59.14 GB
model size: 48.37
```

If auto-tuning for …

Again, this all could be further improved; leaving it at that for now.
@alexsamardzic so PyTorch now, with the padding restrictions removed, is strictly better than FBGEMM?
The FBGEMM kernel is not included in the results above. To have it activated, on top of the changes mentioned for the last run (that was about using …), …:

```
$ python generate.py --compile --moe_quant fp8dq-base
Average tokens/sec: 23.51
Memory used: 59.14 GB
model size: 48.37
```

Plus, the output is garbage. However, regarding the performance, note this: compilation is at the moment disabled around calls to the FBGEMM kernel, and if enabled it would error out. I'm not sure @HDCharles would be interested in working further on that branch, but IMO both the FBGEMM and PyTorch Triton kernels should have similar performance, so the FBGEMM kernel usage may be safely skipped.
Summary:
Current status:
Both kernels are working. The padding is a significant issue with compile for the PyTorch kernel, while the FBGEMM kernel doesn't seem compatible with compile. Hopefully this can be handled using the changes mentioned below to avoid the data-dependent padding.
Todo:
test the no-padding, compilable PyTorch kernel
change the base integration to grouped_gemm (another PR)
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags: