From b799f10c708a223ca30e622d658790bee21a553e Mon Sep 17 00:00:00 2001 From: Harald van Dijk Date: Thu, 25 Jun 2026 01:50:01 +0100 Subject: [PATCH 1/7] [mlir][dxsa] Add bfi instruction This commit generalises the DXSA_*Op to allow arbitrary numbers of arguments. --- .../mlir/Dialect/DXSA/IR/DXSABitwiseOps.td | 21 ++++++++ .../mlir/Dialect/DXSA/IR/DXSAOpBase.td | 48 ++++++++----------- mlir/lib/Target/DXSA/BinaryParser.cpp | 2 + mlir/test/Target/DXSA/bfi.test | 8 ++++ 4 files changed, 50 insertions(+), 29 deletions(-) create mode 100644 mlir/test/Target/DXSA/bfi.test diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td index 8a4b91c3df3e..0b726b372ba9 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td @@ -34,6 +34,27 @@ def DXSA_And : DXSA_BinaryOp<"and"> { }]; } +//===----------------------------------------------------------------------===// +// dxsa.bfi +//===----------------------------------------------------------------------===// + +def DXSA_BFI : DXSA_PlainOp<"bfi", 1, 4> { + let summary = "bit field insert"; + let description = [{ + The `dxsa.bfi` operation takes a bit range from the LSB of a number + and places that number of bits in another number at any offset. + + `$src0` specifies the bitfield width to take from `$src2`. `$src1` + specifies the offset at which to insert the bitfield in `$src3`. + + Example: + + ```mlir + dxsa.bfi r<0, >, l(0x1E), l(0x2), v<0, >, l(0x1) + ``` + }]; +} + //===----------------------------------------------------------------------===// // dxsa.bfrev //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td index 6f67ca4c3329..c49df3209008 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td @@ -25,40 +25,30 @@ class DXSA_Op traits = []> : Op; //===----------------------------------------------------------------------===// -// DXSA shared bases for ops with inline operands +// DXSA shared base for ops with inline operands //===----------------------------------------------------------------------===// -class DXSA_UnaryOp : DXSA_Op { - let arguments = (ins - DXSA_DstOperandAttr:$dst, - DXSA_SrcOperandAttr:$src, - OptionalAttr:$precise); +class DXSA_PlainOp : DXSA_Op { + defvar args = !foldl(!dag(ins, !listsplat(0, !add(dsts, srcs, 1)), ?), + !range(!add(dsts, srcs, 1)), + acc, n, + !cond(!lt(n, dsts): + !setdagname(!setdagarg(acc, n, DXSA_DstOperandAttr), n, !cond(!eq(dsts, 1): "dst", true: !strconcat("dst", !cast(n)))), + !lt(n, !add(dsts, srcs)): + !setdagname(!setdagarg(acc, n, DXSA_SrcOperandAttr), n, !cond(!eq(srcs, 1): "src", + !and(!eq(srcs, 2), !eq(n, dsts)): "lhs", + !eq(srcs, 2): "rhs", + true: !strconcat("src", !cast(!sub(n, dsts))))), + true: + !setdagname(!setdagarg(acc, n, OptionalAttr), n, "precise"))); + let arguments = args; let results = (outs); let assemblyFormat = - "(`precise` $precise^)? $dst `,` $src attr-dict"; + !strconcat(!foldl("(`precise` $precise^)?", !range(!add(dsts, srcs)), acc, n, !strconcat(acc, !cond(!eq(n, 0): " $", true: " `,` $"), !getdagname(args, n))), " attr-dict"); } -class DXSA_BinaryOp : DXSA_Op { - let arguments = (ins - DXSA_DstOperandAttr:$dst, - DXSA_SrcOperandAttr:$lhs, - DXSA_SrcOperandAttr:$rhs, - OptionalAttr:$precise); - let results = (outs); - let assemblyFormat = - "(`precise` $precise^)? $dst `,` $lhs `,` $rhs attr-dict"; -} - -class DXSA_TernaryOp : DXSA_Op { - let arguments = (ins - DXSA_DstOperandAttr:$dst, - DXSA_SrcOperandAttr:$src0, - DXSA_SrcOperandAttr:$src1, - DXSA_SrcOperandAttr:$src2, - OptionalAttr:$precise); - let results = (outs); - let assemblyFormat = - "(`precise` $precise^)? $dst `,` $src0 `,` $src1 `,` $src2 attr-dict"; -} +class DXSA_UnaryOp : DXSA_PlainOp; +class DXSA_BinaryOp : DXSA_PlainOp; +class DXSA_TernaryOp : DXSA_PlainOp; #endif // MLIR_DIALECT_DXSA_IR_DXSAOPBASE diff --git a/mlir/lib/Target/DXSA/BinaryParser.cpp b/mlir/lib/Target/DXSA/BinaryParser.cpp index b783e6be85f8..ff42301984e9 100644 --- a/mlir/lib/Target/DXSA/BinaryParser.cpp +++ b/mlir/lib/Target/DXSA/BinaryParser.cpp @@ -2383,6 +2383,8 @@ class Parser { // Bitwise instructions case D3D10_SB_OPCODE_AND: return PLAIN_OP(And, 1, 2, HasPreciseAttr::Yes); + case D3D11_SB_OPCODE_BFI: + return PLAIN_OP(BFI, 1, 4, HasPreciseAttr::Yes); case D3D11_SB_OPCODE_BFREV: return PLAIN_OP(BFRev, 1, 1, HasPreciseAttr::Yes); case D3D11_SB_OPCODE_COUNTBITS: diff --git a/mlir/test/Target/DXSA/bfi.test b/mlir/test/Target/DXSA/bfi.test new file mode 100644 index 000000000000..1f66e0ebe668 --- /dev/null +++ b/mlir/test/Target/DXSA/bfi.test @@ -0,0 +1,8 @@ +// RUN: mlir-translate --import-dxsa-hex %s | FileCheck %s +// RUN: mlir-translate --import-dxsa-hex %s | mlir-opt --verify-roundtrip + +// CHECK: dxsa.module { + +// CHECK-NEXT: dxsa.bfi r<0, >, l(0x1E), l(0x2), v<0, >, l(0x1) +0x0B00008C, 0x00100022, 0x00000000, 0x00004001, 0x0000001E, 0x00004001, 0x00000002, 0x0010100A, 0x00000000, 0x00004001, 0x00000001 +// CHECK-NEXT: } From aa7f911a60106c25420a05d27ea2d254068c7ca9 Mon Sep 17 00:00:00 2001 From: Harald van Dijk Date: Thu, 25 Jun 2026 20:14:57 +0100 Subject: [PATCH 2/7] Simplify TableGen, prepare for hasprecise = 0 --- .../mlir/Dialect/DXSA/IR/DXSAOpBase.td | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td index c49df3209008..7566fa60292e 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td @@ -28,23 +28,27 @@ class DXSA_Op traits = []> : // DXSA shared base for ops with inline operands //===----------------------------------------------------------------------===// -class DXSA_PlainOp : DXSA_Op { - defvar args = !foldl(!dag(ins, !listsplat(0, !add(dsts, srcs, 1)), ?), - !range(!add(dsts, srcs, 1)), +class DXSA_PlainOp : DXSA_Op { + defvar dstnames = !cond(!eq(dsts, 1): ["dst"], + true: !foreach(n, !range(dsts), !strconcat("dst", !cast(n)))); + defvar srcnames = !cond(!eq(srcs, 1): ["src"], + !eq(srcs, 2): ["lhs", "rhs"], + true: !foreach(n, !range(srcs), !strconcat("src", !cast(n)))); + defvar precisename = !listsplat("precise", !cast(hasprecise)); + defvar argnames = !listconcat(dstnames, srcnames, precisename); + + defvar args = !foldl(!dag(ins, ?, argnames), + !range(!add(dsts, srcs, hasprecise)), acc, n, - !cond(!lt(n, dsts): - !setdagname(!setdagarg(acc, n, DXSA_DstOperandAttr), n, !cond(!eq(dsts, 1): "dst", true: !strconcat("dst", !cast(n)))), - !lt(n, !add(dsts, srcs)): - !setdagname(!setdagarg(acc, n, DXSA_SrcOperandAttr), n, !cond(!eq(srcs, 1): "src", - !and(!eq(srcs, 2), !eq(n, dsts)): "lhs", - !eq(srcs, 2): "rhs", - true: !strconcat("src", !cast(!sub(n, dsts))))), - true: - !setdagname(!setdagarg(acc, n, OptionalAttr), n, "precise"))); + !cond(!lt(n, dsts): !setdagarg(acc, n, DXSA_DstOperandAttr), + !lt(n, !add(dsts, srcs)): !setdagarg(acc, n, DXSA_SrcOperandAttr), + true: !setdagarg(acc, n, OptionalAttr))); let arguments = args; let results = (outs); - let assemblyFormat = - !strconcat(!foldl("(`precise` $precise^)?", !range(!add(dsts, srcs)), acc, n, !strconcat(acc, !cond(!eq(n, 0): " $", true: " `,` $"), !getdagname(args, n))), " attr-dict"); + let assemblyFormat = !interleave(!listconcat( + !cond(hasprecise: ["(`precise` $precise^)?"], true: []), + !tail(!cond(!eq(!add(dsts, srcs), 0): [""], true: !listflatten(!foreach(n, !listconcat(dstnames, srcnames), ["`,`", !strconcat("$", n)])))), + ["attr-dict"]), " "); } class DXSA_UnaryOp : DXSA_PlainOp; From 7c0216f0bb4c9b347e7b1ded9ddb2cc40d787d5a Mon Sep 17 00:00:00 2001 From: Harald van Dijk Date: Thu, 25 Jun 2026 20:22:13 +0100 Subject: [PATCH 3/7] Simplify TableGen further --- mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td index 7566fa60292e..2bfd9730df34 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td @@ -39,10 +39,10 @@ class DXSA_PlainOp : DX defvar args = !foldl(!dag(ins, ?, argnames), !range(!add(dsts, srcs, hasprecise)), - acc, n, - !cond(!lt(n, dsts): !setdagarg(acc, n, DXSA_DstOperandAttr), - !lt(n, !add(dsts, srcs)): !setdagarg(acc, n, DXSA_SrcOperandAttr), - true: !setdagarg(acc, n, OptionalAttr))); + acc, n, !setdagarg(acc, n, !cond( + !lt(n, dsts): DXSA_DstOperandAttr, + !lt(n, !add(dsts, srcs)): DXSA_SrcOperandAttr, + true: OptionalAttr))); let arguments = args; let results = (outs); let assemblyFormat = !interleave(!listconcat( From 9a219cf9eec6fc6731789449224496a6e6c76049 Mon Sep 17 00:00:00 2001 From: Harald van Dijk Date: Fri, 26 Jun 2026 17:47:00 +0100 Subject: [PATCH 4/7] Remove hasprecise, add DXSA_NullaryOp --- mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td index 2bfd9730df34..492ad424b0fa 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td @@ -28,17 +28,17 @@ class DXSA_Op traits = []> : // DXSA shared base for ops with inline operands //===----------------------------------------------------------------------===// -class DXSA_PlainOp : DXSA_Op { +class DXSA_PlainOp : DXSA_Op { defvar dstnames = !cond(!eq(dsts, 1): ["dst"], true: !foreach(n, !range(dsts), !strconcat("dst", !cast(n)))); defvar srcnames = !cond(!eq(srcs, 1): ["src"], !eq(srcs, 2): ["lhs", "rhs"], true: !foreach(n, !range(srcs), !strconcat("src", !cast(n)))); - defvar precisename = !listsplat("precise", !cast(hasprecise)); + defvar precisename = ["precise"]; defvar argnames = !listconcat(dstnames, srcnames, precisename); defvar args = !foldl(!dag(ins, ?, argnames), - !range(!add(dsts, srcs, hasprecise)), + !range(!add(dsts, srcs, 1)), acc, n, !setdagarg(acc, n, !cond( !lt(n, dsts): DXSA_DstOperandAttr, !lt(n, !add(dsts, srcs)): DXSA_SrcOperandAttr, @@ -46,13 +46,14 @@ class DXSA_PlainOp : DX let arguments = args; let results = (outs); let assemblyFormat = !interleave(!listconcat( - !cond(hasprecise: ["(`precise` $precise^)?"], true: []), + ["(`precise` $precise^)?"], !tail(!cond(!eq(!add(dsts, srcs), 0): [""], true: !listflatten(!foreach(n, !listconcat(dstnames, srcnames), ["`,`", !strconcat("$", n)])))), ["attr-dict"]), " "); } -class DXSA_UnaryOp : DXSA_PlainOp; -class DXSA_BinaryOp : DXSA_PlainOp; -class DXSA_TernaryOp : DXSA_PlainOp; +class DXSA_NullaryOp : DXSA_PlainOp; +class DXSA_UnaryOp : DXSA_PlainOp; +class DXSA_BinaryOp : DXSA_PlainOp; +class DXSA_TernaryOp : DXSA_PlainOp; #endif // MLIR_DIALECT_DXSA_IR_DXSAOPBASE From d2127c4ab449e24fbf7b6a21b52f1040a2637031 Mon Sep 17 00:00:00 2001 From: Harald van Dijk Date: Mon, 29 Jun 2026 15:43:22 +0100 Subject: [PATCH 5/7] Rename to DXSA_BaseOp, drop HasPreciseAttr (always true), use throughout --- .../mlir/Dialect/DXSA/IR/DXSAAtomicOps.td | 49 +---- .../mlir/Dialect/DXSA/IR/DXSABitwiseOps.td | 2 +- .../mlir/Dialect/DXSA/IR/DXSAFPArithOps.td | 23 +- .../mlir/Dialect/DXSA/IR/DXSAOpBase.td | 36 ++-- mlir/lib/Target/DXSA/BinaryParser.cpp | 201 +++++++++--------- 5 files changed, 117 insertions(+), 194 deletions(-) diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAAtomicOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAAtomicOps.td index b83f0a9a0c37..c3dbfa45883c 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAAtomicOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAAtomicOps.td @@ -19,14 +19,7 @@ include "mlir/Dialect/DXSA/IR/DXSAOpBase.td" // DXSA shared base for atomic memory ops //===----------------------------------------------------------------------===// -class DXSA_AtomicOp : DXSA_Op { - let arguments = (ins - DXSA_DstOperandAttr:$dst, - DXSA_SrcOperandAttr:$dst_address, - DXSA_SrcOperandAttr:$src0); - let results = (outs); - let assemblyFormat = "$dst `,` $dst_address `,` $src0 attr-dict"; -} +class DXSA_AtomicOp : DXSA_BaseOp; //===----------------------------------------------------------------------===// // dxsa.atomic_and @@ -209,16 +202,7 @@ def DXSA_AtomicUMin : DXSA_AtomicOp<"atomic_umin"> { // DXSA shared base for immediate atomic ops returning the prior value //===----------------------------------------------------------------------===// -class DXSA_ImmAtomicBinaryOp : DXSA_Op { - let arguments = (ins - DXSA_DstOperandAttr:$dst0, - DXSA_DstOperandAttr:$dst1, - DXSA_SrcOperandAttr:$dst_address, - DXSA_SrcOperandAttr:$src0); - let results = (outs); - let assemblyFormat = - "$dst0 `,` $dst1 `,` $dst_address `,` $src0 attr-dict"; -} +class DXSA_ImmAtomicBinaryOp : DXSA_BaseOp; //===----------------------------------------------------------------------===// // dxsa.imm_atomic_iadd @@ -418,7 +402,7 @@ def DXSA_ImmAtomicUMin : DXSA_ImmAtomicBinaryOp<"imm_atomic_umin"> { // dxsa.atomic_cmp_store //===----------------------------------------------------------------------===// -def DXSA_AtomicCmpStore : DXSA_Op<"atomic_cmp_store"> { +def DXSA_AtomicCmpStore : DXSA_BaseOp<"atomic_cmp_store", ["dst"], ["dst_address", "src0", "src1"]> { let summary = "atomic compare and conditional write to memory"; let description = [{ The `dxsa.atomic_cmp_store` operation atomically compares `$src0` with the @@ -437,14 +421,6 @@ def DXSA_AtomicCmpStore : DXSA_Op<"atomic_cmp_store"> { dxsa.atomic_cmp_store u<0>, r<1>, r<2>, r<3> ``` }]; - let arguments = (ins - DXSA_DstOperandAttr:$dst, - DXSA_SrcOperandAttr:$dst_address, - DXSA_SrcOperandAttr:$src0, - DXSA_SrcOperandAttr:$src1); - let results = (outs); - let assemblyFormat = - "$dst `,` $dst_address `,` $src0 `,` $src1 attr-dict"; } //===----------------------------------------------------------------------===// @@ -474,7 +450,7 @@ def DXSA_ImmAtomicExch : DXSA_ImmAtomicBinaryOp<"imm_atomic_exch"> { // dxsa.imm_atomic_cmp_exch //===----------------------------------------------------------------------===// -def DXSA_ImmAtomicCmpExch : DXSA_Op<"imm_atomic_cmp_exch"> { +def DXSA_ImmAtomicCmpExch : DXSA_BaseOp<"imm_atomic_cmp_exch", ["dst0", "dst1"], ["dst_address", "src0", "src1"]> { let summary = "atomic compare and exchange to memory, returning the prior value"; let description = [{ @@ -494,28 +470,13 @@ def DXSA_ImmAtomicCmpExch : DXSA_Op<"imm_atomic_cmp_exch"> { dxsa.imm_atomic_cmp_exch r<0, >, u<0>, r<1>, r<2>, r<3> ``` }]; - let arguments = (ins - DXSA_DstOperandAttr:$dst0, - DXSA_DstOperandAttr:$dst1, - DXSA_SrcOperandAttr:$dst_address, - DXSA_SrcOperandAttr:$src0, - DXSA_SrcOperandAttr:$src1); - let results = (outs); - let assemblyFormat = - "$dst0 `,` $dst1 `,` $dst_address `,` $src0 `,` $src1 attr-dict"; } //===----------------------------------------------------------------------===// // DXSA shared base for UnorderedAccessView (UAV) counter atomics //===----------------------------------------------------------------------===// -class DXSA_ImmAtomicCounterOp : DXSA_Op { - let arguments = (ins - DXSA_DstOperandAttr:$dst, - DXSA_SrcOperandAttr:$uav); - let results = (outs); - let assemblyFormat = "$dst `,` $uav attr-dict"; -} +class DXSA_ImmAtomicCounterOp : DXSA_BaseOp; //===----------------------------------------------------------------------===// // dxsa.imm_atomic_alloc diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td index 0b726b372ba9..840940b860fd 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td @@ -38,7 +38,7 @@ def DXSA_And : DXSA_BinaryOp<"and"> { // dxsa.bfi //===----------------------------------------------------------------------===// -def DXSA_BFI : DXSA_PlainOp<"bfi", 1, 4> { +def DXSA_BFI : DXSA_BaseOp<"bfi", ["dst"], ["src0", "src1", "src2", "src3"]> { let summary = "bit field insert"; let description = [{ The `dxsa.bfi` operation takes a bit range from the LSB of a number diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAFPArithOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAFPArithOps.td index 2a0d3c18c54b..dd2b3f0f1010 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAFPArithOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAFPArithOps.td @@ -354,17 +354,7 @@ def DXSA_LogSat : DXSA_UnaryOp<"log_sat"> { //===----------------------------------------------------------------------===// // Shared base for the multiply-add family: `$dst = $lhs * $rhs + $acc`. -class DXSA_MultiplyAddOp : DXSA_Op { - let arguments = (ins - DXSA_DstOperandAttr:$dst, - DXSA_SrcOperandAttr:$lhs, - DXSA_SrcOperandAttr:$rhs, - DXSA_SrcOperandAttr:$acc, - OptionalAttr:$precise); - let results = (outs); - let assemblyFormat = - "(`precise` $precise^)? $dst `,` $lhs `,` $rhs `,` $acc attr-dict"; -} +class DXSA_MultiplyAddOp : DXSA_BaseOp; def DXSA_Mad : DXSA_MultiplyAddOp<"mad"> { let summary = "component-wise floating-point multiply-add"; @@ -782,16 +772,7 @@ def DXSA_RsqSat : DXSA_UnaryOp<"rsq_sat"> { // Shared base for sine/cosine: sin of `$operand` into `$sin`, cosine into // `$cos`. Either destination may be `null` when that result is not needed. -class DXSA_SincosOp : DXSA_Op { - let arguments = (ins - DXSA_DstOperandAttr:$sin, - DXSA_DstOperandAttr:$cos, - DXSA_SrcOperandAttr:$operand, - OptionalAttr:$precise); - let results = (outs); - let assemblyFormat = - "(`precise` $precise^)? $sin `,` $cos `,` $operand attr-dict"; -} +class DXSA_SincosOp : DXSA_BaseOp; def DXSA_Sincos : DXSA_SincosOp<"sincos"> { let summary = "component-wise floating-point sine and cosine"; diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td index 492ad424b0fa..5e05bbf01aac 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td @@ -25,35 +25,27 @@ class DXSA_Op traits = []> : Op; //===----------------------------------------------------------------------===// -// DXSA shared base for ops with inline operands +// DXSA shared base for ops with all operands specified inline //===----------------------------------------------------------------------===// -class DXSA_PlainOp : DXSA_Op { - defvar dstnames = !cond(!eq(dsts, 1): ["dst"], - true: !foreach(n, !range(dsts), !strconcat("dst", !cast(n)))); - defvar srcnames = !cond(!eq(srcs, 1): ["src"], - !eq(srcs, 2): ["lhs", "rhs"], - true: !foreach(n, !range(srcs), !strconcat("src", !cast(n)))); - defvar precisename = ["precise"]; - defvar argnames = !listconcat(dstnames, srcnames, precisename); - - defvar args = !foldl(!dag(ins, ?, argnames), - !range(!add(dsts, srcs, 1)), - acc, n, !setdagarg(acc, n, !cond( - !lt(n, dsts): DXSA_DstOperandAttr, - !lt(n, !add(dsts, srcs)): DXSA_SrcOperandAttr, - true: OptionalAttr))); - let arguments = args; +class DXSA_BaseOp dstnames, list srcnames> : DXSA_Op { + let arguments = !setdagarg(!dag(ins, + !listconcat(!listsplat(DXSA_DstOperandAttr, !size(dstnames)), + !listsplat(DXSA_SrcOperandAttr, !add(!size(srcnames), 1))), + !listconcat(dstnames, srcnames, ["precise"])), + !add(!size(dstnames), !size(srcnames)), OptionalAttr); let results = (outs); let assemblyFormat = !interleave(!listconcat( ["(`precise` $precise^)?"], - !tail(!cond(!eq(!add(dsts, srcs), 0): [""], true: !listflatten(!foreach(n, !listconcat(dstnames, srcnames), ["`,`", !strconcat("$", n)])))), + !tail(!cond( + !and(!empty(dstnames), !empty(srcnames)): [""], + true: !listflatten(!foreach(name, !listconcat(dstnames, srcnames), ["`,`", !strconcat("$", name)])))), ["attr-dict"]), " "); } -class DXSA_NullaryOp : DXSA_PlainOp; -class DXSA_UnaryOp : DXSA_PlainOp; -class DXSA_BinaryOp : DXSA_PlainOp; -class DXSA_TernaryOp : DXSA_PlainOp; +class DXSA_NullaryOp dstnames = ["dst"]> : DXSA_BaseOp; +class DXSA_UnaryOp dstnames = ["dst"]> : DXSA_BaseOp; +class DXSA_BinaryOp dstnames = ["dst"]> : DXSA_BaseOp; +class DXSA_TernaryOp dstnames = ["dst"]> : DXSA_BaseOp; #endif // MLIR_DIALECT_DXSA_IR_DXSAOPBASE diff --git a/mlir/lib/Target/DXSA/BinaryParser.cpp b/mlir/lib/Target/DXSA/BinaryParser.cpp index 07913483c8fb..161c4762955d 100644 --- a/mlir/lib/Target/DXSA/BinaryParser.cpp +++ b/mlir/lib/Target/DXSA/BinaryParser.cpp @@ -387,9 +387,6 @@ struct InstructionModifier { uint32_t saturate{0}; }; -// Whether an op carries a precise modifier attribute. -enum class HasPreciseAttr { No, Yes }; - struct OperandModifier { uint32_t modifier{0}; uint32_t minPrecision{0}; @@ -785,7 +782,7 @@ class DXBuilder { context, static_cast(preciseMask)); } - template Instruction buildOp(uint32_t preciseMask, Location loc, @@ -795,13 +792,8 @@ class DXBuilder { [&](auto... dstOperands) { return std::apply( [&](auto... srcOperands) -> Instruction { - if constexpr (HasPrecise == HasPreciseAttr::Yes) - return OpT::create(builder, loc, dstOperands..., - srcOperands..., - buildPreciseAttr(preciseMask)); - else - return OpT::create(builder, loc, dstOperands..., - srcOperands...); + return OpT::create(builder, loc, dstOperands..., srcOperands..., + buildPreciseAttr(preciseMask)); }, srcs); }, @@ -1674,8 +1666,8 @@ class Parser { return operands; } - template + template FailureOr decodeOp(size_t beginOffset, uint32_t length, const InstructionModifier &modifier, Location loc) { @@ -1689,10 +1681,8 @@ class Parser { return failure(); if constexpr (!std::is_same_v) if (modifier.saturate) - return builder.buildOp(modifier.preciseMask, loc, - *dsts, *srcs); - return builder.buildOp(modifier.preciseMask, loc, *dsts, - *srcs); + return builder.buildOp(modifier.preciseMask, loc, *dsts, *srcs); + return builder.buildOp(modifier.preciseMask, loc, *dsts, *srcs); } FailureOr parseDclInput(Location loc) { @@ -2294,191 +2284,190 @@ class Parser { unsigned numOperands = instrInfo[opcode].numOperands; -#define SATURABLE_OP(MNEMONIC, NUM_DST_OPERANDS, NUM_SRC_OPERANDS, \ - HAS_PRECISE) \ - decodeOp(beginOffset, instructionLengthInTokens, modifier, \ getLocation()) -#define PLAIN_OP(MNEMONIC, NUM_DST_OPERANDS, NUM_SRC_OPERANDS, HAS_PRECISE) \ - decodeOp(beginOffset, instructionLengthInTokens, modifier, \ getLocation()) switch (opcode) { // Floating-point arithmetic instructions case D3D10_SB_OPCODE_ADD: - return SATURABLE_OP(Add, 1, 2, HasPreciseAttr::Yes); + return SATURABLE_OP(Add, 1, 2); case D3D10_SB_OPCODE_DIV: - return SATURABLE_OP(Div, 1, 2, HasPreciseAttr::Yes); + return SATURABLE_OP(Div, 1, 2); case D3D10_SB_OPCODE_DP2: - return SATURABLE_OP(Dp2, 1, 2, HasPreciseAttr::Yes); + return SATURABLE_OP(Dp2, 1, 2); case D3D10_SB_OPCODE_DP3: - return SATURABLE_OP(Dp3, 1, 2, HasPreciseAttr::Yes); + return SATURABLE_OP(Dp3, 1, 2); case D3D10_SB_OPCODE_DP4: - return SATURABLE_OP(Dp4, 1, 2, HasPreciseAttr::Yes); + return SATURABLE_OP(Dp4, 1, 2); case D3D10_SB_OPCODE_EXP: - return SATURABLE_OP(Exp, 1, 1, HasPreciseAttr::Yes); + return SATURABLE_OP(Exp, 1, 1); case D3D10_SB_OPCODE_FRC: - return SATURABLE_OP(Frc, 1, 1, HasPreciseAttr::Yes); + return SATURABLE_OP(Frc, 1, 1); case D3D10_SB_OPCODE_LOG: - return SATURABLE_OP(Log, 1, 1, HasPreciseAttr::Yes); + return SATURABLE_OP(Log, 1, 1); case D3D10_SB_OPCODE_MAD: - return SATURABLE_OP(Mad, 1, 3, HasPreciseAttr::Yes); + return SATURABLE_OP(Mad, 1, 3); case D3D10_SB_OPCODE_MAX: - return SATURABLE_OP(Max, 1, 2, HasPreciseAttr::Yes); + return SATURABLE_OP(Max, 1, 2); case D3D10_SB_OPCODE_MIN: - return SATURABLE_OP(Min, 1, 2, HasPreciseAttr::Yes); + return SATURABLE_OP(Min, 1, 2); case D3D10_SB_OPCODE_MUL: - return SATURABLE_OP(Mul, 1, 2, HasPreciseAttr::Yes); + return SATURABLE_OP(Mul, 1, 2); case D3D11_SB_OPCODE_RCP: - return SATURABLE_OP(Rcp, 1, 1, HasPreciseAttr::Yes); + return SATURABLE_OP(Rcp, 1, 1); case D3D10_SB_OPCODE_ROUND_NE: - return SATURABLE_OP(RoundNe, 1, 1, HasPreciseAttr::Yes); + return SATURABLE_OP(RoundNe, 1, 1); case D3D10_SB_OPCODE_ROUND_NI: - return SATURABLE_OP(RoundNi, 1, 1, HasPreciseAttr::Yes); + return SATURABLE_OP(RoundNi, 1, 1); case D3D10_SB_OPCODE_ROUND_PI: - return SATURABLE_OP(RoundPi, 1, 1, HasPreciseAttr::Yes); + return SATURABLE_OP(RoundPi, 1, 1); case D3D10_SB_OPCODE_ROUND_Z: - return SATURABLE_OP(RoundZ, 1, 1, HasPreciseAttr::Yes); + return SATURABLE_OP(RoundZ, 1, 1); case D3D10_SB_OPCODE_RSQ: - return SATURABLE_OP(Rsq, 1, 1, HasPreciseAttr::Yes); + return SATURABLE_OP(Rsq, 1, 1); case D3D10_SB_OPCODE_SINCOS: - return SATURABLE_OP(Sincos, 2, 1, HasPreciseAttr::Yes); + return SATURABLE_OP(Sincos, 2, 1); case D3D10_SB_OPCODE_SQRT: - return SATURABLE_OP(Sqrt, 1, 1, HasPreciseAttr::Yes); + return SATURABLE_OP(Sqrt, 1, 1); // Type conversion instructions case D3D11_SB_OPCODE_DTOF: - return PLAIN_OP(DToF, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(DToF, 1, 1); case D3D11_1_SB_OPCODE_DTOI: - return PLAIN_OP(DToI, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(DToI, 1, 1); case D3D11_1_SB_OPCODE_DTOU: - return PLAIN_OP(DToU, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(DToU, 1, 1); case D3D11_SB_OPCODE_F16TOF32: - return PLAIN_OP(F16ToF32, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(F16ToF32, 1, 1); case D3D11_SB_OPCODE_F32TOF16: - return PLAIN_OP(F32ToF16, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(F32ToF16, 1, 1); case D3D11_SB_OPCODE_FTOD: - return PLAIN_OP(FToD, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(FToD, 1, 1); case D3D10_SB_OPCODE_FTOI: - return PLAIN_OP(FToI, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(FToI, 1, 1); case D3D10_SB_OPCODE_FTOU: - return PLAIN_OP(FToU, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(FToU, 1, 1); case D3D11_1_SB_OPCODE_ITOD: - return PLAIN_OP(IToD, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(IToD, 1, 1); case D3D10_SB_OPCODE_ITOF: - return PLAIN_OP(IToF, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(IToF, 1, 1); case D3D11_1_SB_OPCODE_UTOD: - return PLAIN_OP(UToD, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(UToD, 1, 1); case D3D10_SB_OPCODE_UTOF: - return PLAIN_OP(UToF, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(UToF, 1, 1); // Comparison instructions case D3D10_SB_OPCODE_EQ: - return PLAIN_OP(Eq, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Eq, 1, 2); case D3D10_SB_OPCODE_GE: - return PLAIN_OP(Ge, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Ge, 1, 2); case D3D10_SB_OPCODE_LT: - return PLAIN_OP(Lt, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Lt, 1, 2); case D3D10_SB_OPCODE_NE: - return PLAIN_OP(Ne, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Ne, 1, 2); case D3D10_SB_OPCODE_IEQ: - return PLAIN_OP(Ieq, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Ieq, 1, 2); case D3D10_SB_OPCODE_IGE: - return PLAIN_OP(Ige, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Ige, 1, 2); case D3D10_SB_OPCODE_ILT: - return PLAIN_OP(Ilt, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Ilt, 1, 2); case D3D10_SB_OPCODE_INE: - return PLAIN_OP(Ine, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Ine, 1, 2); case D3D10_SB_OPCODE_UGE: - return PLAIN_OP(Uge, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Uge, 1, 2); case D3D10_SB_OPCODE_ULT: - return PLAIN_OP(Ult, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Ult, 1, 2); // Integer arithmetic instructions case D3D10_SB_OPCODE_IADD: - return PLAIN_OP(IAdd, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(IAdd, 1, 2); case D3D10_SB_OPCODE_IMAX: - return PLAIN_OP(IMax, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(IMax, 1, 2); case D3D10_SB_OPCODE_IMIN: - return PLAIN_OP(IMin, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(IMin, 1, 2); case D3D10_SB_OPCODE_INEG: - return PLAIN_OP(INeg, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(INeg, 1, 1); case D3D10_SB_OPCODE_UMAX: - return PLAIN_OP(UMax, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(UMax, 1, 2); case D3D10_SB_OPCODE_UMIN: - return PLAIN_OP(UMin, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(UMin, 1, 2); // Bitwise instructions case D3D10_SB_OPCODE_AND: - return PLAIN_OP(And, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(And, 1, 2); case D3D11_SB_OPCODE_BFI: - return PLAIN_OP(BFI, 1, 4, HasPreciseAttr::Yes); + return PLAIN_OP(BFI, 1, 4); case D3D11_SB_OPCODE_BFREV: - return PLAIN_OP(BFRev, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(BFRev, 1, 1); case D3D11_SB_OPCODE_COUNTBITS: - return PLAIN_OP(CountBits, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(CountBits, 1, 1); case D3D11_SB_OPCODE_FIRSTBIT_LO: - return PLAIN_OP(FirstBitLo, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(FirstBitLo, 1, 1); case D3D11_SB_OPCODE_FIRSTBIT_HI: - return PLAIN_OP(FirstBitHi, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(FirstBitHi, 1, 1); case D3D11_SB_OPCODE_FIRSTBIT_SHI: - return PLAIN_OP(FirstBitSHi, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(FirstBitSHi, 1, 1); case D3D11_SB_OPCODE_IBFE: - return PLAIN_OP(IBFE, 1, 3, HasPreciseAttr::Yes); + return PLAIN_OP(IBFE, 1, 3); case D3D10_SB_OPCODE_ISHL: - return PLAIN_OP(IShl, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(IShl, 1, 2); case D3D10_SB_OPCODE_ISHR: - return PLAIN_OP(IShr, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(IShr, 1, 2); case D3D10_SB_OPCODE_NOT: - return PLAIN_OP(Not, 1, 1, HasPreciseAttr::Yes); + return PLAIN_OP(Not, 1, 1); case D3D10_SB_OPCODE_OR: - return PLAIN_OP(Or, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Or, 1, 2); case D3D11_SB_OPCODE_UBFE: - return PLAIN_OP(UBFE, 1, 3, HasPreciseAttr::Yes); + return PLAIN_OP(UBFE, 1, 3); case D3D10_SB_OPCODE_USHR: - return PLAIN_OP(UShr, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(UShr, 1, 2); case D3D10_SB_OPCODE_XOR: - return PLAIN_OP(Xor, 1, 2, HasPreciseAttr::Yes); + return PLAIN_OP(Xor, 1, 2); // Atomic instructions case D3D11_SB_OPCODE_ATOMIC_AND: - return PLAIN_OP(AtomicAnd, 1, 2, HasPreciseAttr::No); + return PLAIN_OP(AtomicAnd, 1, 2); case D3D11_SB_OPCODE_ATOMIC_OR: - return PLAIN_OP(AtomicOr, 1, 2, HasPreciseAttr::No); + return PLAIN_OP(AtomicOr, 1, 2); case D3D11_SB_OPCODE_ATOMIC_XOR: - return PLAIN_OP(AtomicXor, 1, 2, HasPreciseAttr::No); + return PLAIN_OP(AtomicXor, 1, 2); case D3D11_SB_OPCODE_ATOMIC_IADD: - return PLAIN_OP(AtomicIAdd, 1, 2, HasPreciseAttr::No); + return PLAIN_OP(AtomicIAdd, 1, 2); case D3D11_SB_OPCODE_ATOMIC_IMAX: - return PLAIN_OP(AtomicIMax, 1, 2, HasPreciseAttr::No); + return PLAIN_OP(AtomicIMax, 1, 2); case D3D11_SB_OPCODE_ATOMIC_IMIN: - return PLAIN_OP(AtomicIMin, 1, 2, HasPreciseAttr::No); + return PLAIN_OP(AtomicIMin, 1, 2); case D3D11_SB_OPCODE_ATOMIC_UMAX: - return PLAIN_OP(AtomicUMax, 1, 2, HasPreciseAttr::No); + return PLAIN_OP(AtomicUMax, 1, 2); case D3D11_SB_OPCODE_ATOMIC_UMIN: - return PLAIN_OP(AtomicUMin, 1, 2, HasPreciseAttr::No); + return PLAIN_OP(AtomicUMin, 1, 2); case D3D11_SB_OPCODE_ATOMIC_CMP_STORE: - return PLAIN_OP(AtomicCmpStore, 1, 3, HasPreciseAttr::No); + return PLAIN_OP(AtomicCmpStore, 1, 3); case D3D11_SB_OPCODE_IMM_ATOMIC_IADD: - return PLAIN_OP(ImmAtomicIAdd, 2, 2, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicIAdd, 2, 2); case D3D11_SB_OPCODE_IMM_ATOMIC_AND: - return PLAIN_OP(ImmAtomicAnd, 2, 2, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicAnd, 2, 2); case D3D11_SB_OPCODE_IMM_ATOMIC_OR: - return PLAIN_OP(ImmAtomicOr, 2, 2, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicOr, 2, 2); case D3D11_SB_OPCODE_IMM_ATOMIC_XOR: - return PLAIN_OP(ImmAtomicXor, 2, 2, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicXor, 2, 2); case D3D11_SB_OPCODE_IMM_ATOMIC_EXCH: - return PLAIN_OP(ImmAtomicExch, 2, 2, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicExch, 2, 2); case D3D11_SB_OPCODE_IMM_ATOMIC_CMP_EXCH: - return PLAIN_OP(ImmAtomicCmpExch, 2, 3, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicCmpExch, 2, 3); case D3D11_SB_OPCODE_IMM_ATOMIC_IMAX: - return PLAIN_OP(ImmAtomicIMax, 2, 2, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicIMax, 2, 2); case D3D11_SB_OPCODE_IMM_ATOMIC_IMIN: - return PLAIN_OP(ImmAtomicIMin, 2, 2, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicIMin, 2, 2); case D3D11_SB_OPCODE_IMM_ATOMIC_UMAX: - return PLAIN_OP(ImmAtomicUMax, 2, 2, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicUMax, 2, 2); case D3D11_SB_OPCODE_IMM_ATOMIC_UMIN: - return PLAIN_OP(ImmAtomicUMin, 2, 2, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicUMin, 2, 2); case D3D11_SB_OPCODE_IMM_ATOMIC_ALLOC: - return PLAIN_OP(ImmAtomicAlloc, 1, 1, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicAlloc, 1, 1); case D3D11_SB_OPCODE_IMM_ATOMIC_CONSUME: - return PLAIN_OP(ImmAtomicConsume, 1, 1, HasPreciseAttr::No); + return PLAIN_OP(ImmAtomicConsume, 1, 1); } #undef SATURABLE_OP #undef PLAIN_OP From 844c25f9b3289d05433632eb331f9032d59fa5be Mon Sep 17 00:00:00 2001 From: Harald van Dijk Date: Tue, 30 Jun 2026 17:57:28 +0100 Subject: [PATCH 6/7] Apply Andrew's suggestion to use class fields rather than arguments --- .../mlir/Dialect/DXSA/IR/DXSAAtomicOps.td | 27 +++++++++--- .../mlir/Dialect/DXSA/IR/DXSABitwiseOps.td | 5 ++- .../mlir/Dialect/DXSA/IR/DXSAFPArithOps.td | 11 ++++- .../mlir/Dialect/DXSA/IR/DXSAOpBase.td | 43 ++++++++++++------- 4 files changed, 62 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAAtomicOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAAtomicOps.td index c3dbfa45883c..84701cb4cf37 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAAtomicOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAAtomicOps.td @@ -19,7 +19,10 @@ include "mlir/Dialect/DXSA/IR/DXSAOpBase.td" // DXSA shared base for atomic memory ops //===----------------------------------------------------------------------===// -class DXSA_AtomicOp : DXSA_BaseOp; +class DXSA_AtomicOp : DXSA_BaseOp { + let srcs = (ins DXSA_SrcOperandAttr:$dst_address, DXSA_SrcOperandAttr:$src0); + let asmFormat = "$dst `,` $dst_address `,` $src0"; +} //===----------------------------------------------------------------------===// // dxsa.atomic_and @@ -202,7 +205,11 @@ def DXSA_AtomicUMin : DXSA_AtomicOp<"atomic_umin"> { // DXSA shared base for immediate atomic ops returning the prior value //===----------------------------------------------------------------------===// -class DXSA_ImmAtomicBinaryOp : DXSA_BaseOp; +class DXSA_ImmAtomicBinaryOp : DXSA_BaseOp { + let dsts = (ins DXSA_DstOperandAttr:$dst0, DXSA_DstOperandAttr:$dst1); + let srcs = (ins DXSA_SrcOperandAttr:$dst_address, DXSA_SrcOperandAttr:$src0); + let asmFormat = "$dst0 `,` $dst1 `,` $dst_address `,` $src0"; +} //===----------------------------------------------------------------------===// // dxsa.imm_atomic_iadd @@ -402,7 +409,7 @@ def DXSA_ImmAtomicUMin : DXSA_ImmAtomicBinaryOp<"imm_atomic_umin"> { // dxsa.atomic_cmp_store //===----------------------------------------------------------------------===// -def DXSA_AtomicCmpStore : DXSA_BaseOp<"atomic_cmp_store", ["dst"], ["dst_address", "src0", "src1"]> { +def DXSA_AtomicCmpStore : DXSA_BaseOp<"atomic_cmp_store"> { let summary = "atomic compare and conditional write to memory"; let description = [{ The `dxsa.atomic_cmp_store` operation atomically compares `$src0` with the @@ -421,6 +428,9 @@ def DXSA_AtomicCmpStore : DXSA_BaseOp<"atomic_cmp_store", ["dst"], ["dst_address dxsa.atomic_cmp_store u<0>, r<1>, r<2>, r<3> ``` }]; + let srcs = (ins DXSA_SrcOperandAttr:$dst_address, + DXSA_SrcOperandAttr:$src0, DXSA_SrcOperandAttr:$src1); + let asmFormat = "$dst `,` $dst_address `,` $src0 `,` $src1"; } //===----------------------------------------------------------------------===// @@ -450,7 +460,7 @@ def DXSA_ImmAtomicExch : DXSA_ImmAtomicBinaryOp<"imm_atomic_exch"> { // dxsa.imm_atomic_cmp_exch //===----------------------------------------------------------------------===// -def DXSA_ImmAtomicCmpExch : DXSA_BaseOp<"imm_atomic_cmp_exch", ["dst0", "dst1"], ["dst_address", "src0", "src1"]> { +def DXSA_ImmAtomicCmpExch : DXSA_BaseOp<"imm_atomic_cmp_exch"> { let summary = "atomic compare and exchange to memory, returning the prior value"; let description = [{ @@ -470,13 +480,20 @@ def DXSA_ImmAtomicCmpExch : DXSA_BaseOp<"imm_atomic_cmp_exch", ["dst0", "dst1"], dxsa.imm_atomic_cmp_exch r<0, >, u<0>, r<1>, r<2>, r<3> ``` }]; + let dsts = (ins DXSA_DstOperandAttr:$dst0, DXSA_DstOperandAttr:$dst1); + let srcs = (ins DXSA_SrcOperandAttr:$dst_address, + DXSA_SrcOperandAttr:$src0, DXSA_SrcOperandAttr:$src1); + let asmFormat = "$dst0 `,` $dst1 `,` $dst_address `,` $src0 `,` $src1"; } //===----------------------------------------------------------------------===// // DXSA shared base for UnorderedAccessView (UAV) counter atomics //===----------------------------------------------------------------------===// -class DXSA_ImmAtomicCounterOp : DXSA_BaseOp; +class DXSA_ImmAtomicCounterOp : DXSA_BaseOp { + let srcs = (ins DXSA_SrcOperandAttr:$uav); + let asmFormat = "$dst `,` $uav"; +} //===----------------------------------------------------------------------===// // dxsa.imm_atomic_alloc diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td index 840940b860fd..fae017df560e 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSABitwiseOps.td @@ -38,7 +38,7 @@ def DXSA_And : DXSA_BinaryOp<"and"> { // dxsa.bfi //===----------------------------------------------------------------------===// -def DXSA_BFI : DXSA_BaseOp<"bfi", ["dst"], ["src0", "src1", "src2", "src3"]> { +def DXSA_BFI : DXSA_BaseOp<"bfi"> { let summary = "bit field insert"; let description = [{ The `dxsa.bfi` operation takes a bit range from the LSB of a number @@ -53,6 +53,9 @@ def DXSA_BFI : DXSA_BaseOp<"bfi", ["dst"], ["src0", "src1", "src2", "src3"]> { dxsa.bfi r<0, >, l(0x1E), l(0x2), v<0, >, l(0x1) ``` }]; + let srcs = (ins DXSA_SrcOperandAttr:$src0, DXSA_SrcOperandAttr:$src1, + DXSA_SrcOperandAttr:$src2, DXSA_SrcOperandAttr:$src3); + let asmFormat = "$dst `,` $src0 `,` $src1 `,` $src2 `,` $src3"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAFPArithOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAFPArithOps.td index dd2b3f0f1010..842acf0c6de4 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAFPArithOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAFPArithOps.td @@ -354,7 +354,10 @@ def DXSA_LogSat : DXSA_UnaryOp<"log_sat"> { //===----------------------------------------------------------------------===// // Shared base for the multiply-add family: `$dst = $lhs * $rhs + $acc`. -class DXSA_MultiplyAddOp : DXSA_BaseOp; +class DXSA_MultiplyAddOp : DXSA_BaseOp { + let srcs = (ins DXSA_SrcOperandAttr:$lhs, DXSA_SrcOperandAttr:$rhs, DXSA_SrcOperandAttr:$acc); + let asmFormat = "$dst `,` $lhs `,` $rhs `,` $acc"; +} def DXSA_Mad : DXSA_MultiplyAddOp<"mad"> { let summary = "component-wise floating-point multiply-add"; @@ -772,7 +775,11 @@ def DXSA_RsqSat : DXSA_UnaryOp<"rsq_sat"> { // Shared base for sine/cosine: sin of `$operand` into `$sin`, cosine into // `$cos`. Either destination may be `null` when that result is not needed. -class DXSA_SincosOp : DXSA_BaseOp; +class DXSA_SincosOp : DXSA_BaseOp { + let dsts = (ins DXSA_DstOperandAttr:$sin, DXSA_DstOperandAttr:$cos); + let srcs = (ins DXSA_SrcOperandAttr:$operand); + let asmFormat = "$sin `,` $cos `,` $operand"; +} def DXSA_Sincos : DXSA_SincosOp<"sincos"> { let summary = "component-wise floating-point sine and cosine"; diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td index 5e05bbf01aac..a68dfb1b9360 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td @@ -28,24 +28,35 @@ class DXSA_Op traits = []> : // DXSA shared base for ops with all operands specified inline //===----------------------------------------------------------------------===// -class DXSA_BaseOp dstnames, list srcnames> : DXSA_Op { - let arguments = !setdagarg(!dag(ins, - !listconcat(!listsplat(DXSA_DstOperandAttr, !size(dstnames)), - !listsplat(DXSA_SrcOperandAttr, !add(!size(srcnames), 1))), - !listconcat(dstnames, srcnames, ["precise"])), - !add(!size(dstnames), !size(srcnames)), OptionalAttr); +class DXSA_BaseOp : DXSA_Op { + dag dsts; + dag srcs; + string asmFormat; + let dsts = (ins DXSA_DstOperandAttr:$dst); + + let arguments = !con(dsts, srcs, (ins OptionalAttr:$precise)); let results = (outs); - let assemblyFormat = !interleave(!listconcat( - ["(`precise` $precise^)?"], - !tail(!cond( - !and(!empty(dstnames), !empty(srcnames)): [""], - true: !listflatten(!foreach(name, !listconcat(dstnames, srcnames), ["`,`", !strconcat("$", name)])))), - ["attr-dict"]), " "); + let assemblyFormat = !strconcat("(`precise` $precise^)? ", asmFormat, " attr-dict"); +} + +class DXSA_NullaryOp : DXSA_BaseOp { + let srcs = (ins); + let asmFormat = "$dst"; +} + +class DXSA_UnaryOp : DXSA_BaseOp { + let srcs = (ins DXSA_SrcOperandAttr:$src); + let asmFormat = "$dst `,` $src"; +} + +class DXSA_BinaryOp : DXSA_BaseOp { + let srcs = (ins DXSA_SrcOperandAttr:$lhs, DXSA_SrcOperandAttr:$rhs); + let asmFormat = "$dst `,` $lhs `,` $rhs"; } -class DXSA_NullaryOp dstnames = ["dst"]> : DXSA_BaseOp; -class DXSA_UnaryOp dstnames = ["dst"]> : DXSA_BaseOp; -class DXSA_BinaryOp dstnames = ["dst"]> : DXSA_BaseOp; -class DXSA_TernaryOp dstnames = ["dst"]> : DXSA_BaseOp; +class DXSA_TernaryOp : DXSA_BaseOp { + let srcs = (ins DXSA_SrcOperandAttr:$src0, DXSA_SrcOperandAttr:$src1, DXSA_SrcOperandAttr:$src2); + let asmFormat = "$dst `,` $src0 `,` $src1 `,` $src2"; +} #endif // MLIR_DIALECT_DXSA_IR_DXSAOPBASE From e481f44e17e94ccae7ef2025304f9b0df9972432 Mon Sep 17 00:00:00 2001 From: Harald van Dijk Date: Wed, 1 Jul 2026 01:55:07 +0100 Subject: [PATCH 7/7] Update imul too --- .../include/mlir/Dialect/DXSA/IR/DXSAIntArithOps.td | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAIntArithOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAIntArithOps.td index 85c50da92690..48e5827732c3 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAIntArithOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAIntArithOps.td @@ -160,7 +160,7 @@ def DXSA_Imad : DXSA_MultiplyAddOp<"imad"> { // dxsa.imul //===----------------------------------------------------------------------===// -def DXSA_Imul : DXSA_Op<"imul"> { +def DXSA_Imul : DXSA_BinaryOp<"imul"> { let summary = "component-wise integer multiply"; let description = [{ The `dxsa.imul` operation computes the component-wise product @@ -176,15 +176,8 @@ def DXSA_Imul : DXSA_Op<"imul"> { dxsa.imul r<7, >, r<3, >, r<3, >, r<4, > ``` }]; - let arguments = (ins - DXSA_DstOperandAttr:$dstHi, - DXSA_DstOperandAttr:$dstLo, - DXSA_SrcOperandAttr:$lhs, - DXSA_SrcOperandAttr:$rhs, - OptionalAttr:$precise); - let results = (outs); - let assemblyFormat = - "(`precise` $precise^)? $dstHi `,` $dstLo `,` $lhs `,` $rhs attr-dict"; + let dsts = (ins DXSA_DstOperandAttr:$dstHi, DXSA_DstOperandAttr:$dstLo); + let asmFormat = "$dstHi `,` $dstLo `,` $lhs `,` $rhs"; } //===----------------------------------------------------------------------===//