Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/dxc/DXIL/DxilOperations.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class OP {
static bool IsDxilOpFuncCallInst(const llvm::Instruction *I);
static bool IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode);
static bool IsDxilOpWave(OpCode C);
static bool IsConvergentOp(OpCode C);
static bool IsDxilOpGradient(OpCode C);
static bool IsDxilOpFeedback(OpCode C);
static bool IsDxilOpBarrier(OpCode C);
Expand Down
16 changes: 16 additions & 0 deletions lib/DXIL/DxilOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
// //
///////////////////////////////////////////////////////////////////////////////

// This file contains functions with generated code,
// and must be clang-format compliant, unless
// clang-format is turned off.

#include "dxc/DXIL/DxilOperations.h"
#include "dxc/DXIL/DxilConstants.h"
#include "dxc/DXIL/DxilInstructions.h"
Expand Down Expand Up @@ -3365,6 +3369,18 @@ bool OP::IsDxilOpWave(OpCode C) {
// OPCODE-WAVE:END
}

bool OP::IsConvergentOp(OpCode C) {
// This accounts for wave and quad ops
unsigned op = (unsigned)C;
if (OP::IsDxilOpWave(C))
return true;
// Account for derivative ops as well:
// DerivCoarseX = 83, DerivCoarseY = 84, DerivFineX = 85, DerivFineY = 86
if (83 <= op && op <= 86)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This generated pattern shouldn't be hand-written. You really want to include all gradient ops (derivative ops and ops that use derivatives). It turns out there's already an IsDxilOpGradient with the op set generated from hctdb.py. You can call that here instead.

return true;
return false;
}

bool OP::IsDxilOpGradient(OpCode C) {
unsigned op = (unsigned)C;
// clang-format off
Expand Down
20 changes: 20 additions & 0 deletions projects/dxilconv/include/DxilConvPasses/ScopeNestIterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,15 @@ class ScopeNestIterator {
}
}

// Reports whether this annotation marks an intermediate branch kind.
// Currently only BranchKind::SwitchFallthrough is treated as
// intermediate; every other kind yields false.
bool IsIntermediate() {
  return Kind == BranchKind::SwitchFallthrough;
}

// Translate a branch annotation to the corresponding event type.
ScopeNestEvent::Type TranslateToNestType() {
switch (Kind) {
Expand All @@ -226,6 +235,8 @@ class ScopeNestIterator {
return ScopeNestEvent::Type::Switch_Begin;
case BranchKind::SwitchBreak:
return ScopeNestEvent::Type::Switch_Break;
case BranchKind::SwitchFallthrough:
return ScopeNestEvent::Type::Body;

case BranchKind::LoopBegin:
return ScopeNestEvent::Type::Loop_Begin;
Expand Down Expand Up @@ -868,6 +879,12 @@ class ScopeNestIterator {
MoveFromTopOfStack();
break;

// Fallthrough just continues in the current scope
case BranchKind::SwitchFallthrough:
SetCurrent(ScopeNestEvent::Type::Body, m_current.Block);
MoveToFirstSuccessor(); // continue exploring successors
break;

// Already exited an old scope.
case BranchKind::IfEnd:
case BranchKind::SwitchEnd:
Expand Down Expand Up @@ -914,6 +931,9 @@ class ScopeNestIterator {
if (BranchAnnotation annotation = BranchAnnotation::Read(B)) {
if (annotation.IsEndScope()) {
EnterEndOfScope(B, annotation.Get());
} else if (annotation.IsIntermediate()) {
// Just treat it as body but keep the annotation
SetCurrent(ScopeNestEvent::Type::Body, B);
} else {
DXASSERT_NOMSG(annotation.IsBeginScope());
StartNewScope(B, annotation.Get());
Expand Down
1 change: 1 addition & 0 deletions projects/dxilconv/include/DxilConvPasses/ScopeNestedCFG.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ enum class BranchKind {
SwitchEnd,
SwitchNoEnd,
SwitchBreak,
SwitchFallthrough,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes the subsequent Loop* values.

Since BranchKind is serialized to metadata, this represents a breaking change in the output. While you could argue that anyone using this pass could synchronize updates to their code when consuming the new version of this pass, it would be a lot safer, and less destabilizing, to add the new value to the end instead.


LoopBegin,
LoopExit,
Expand Down
92 changes: 92 additions & 0 deletions projects/dxilconv/lib/DxilConvPasses/ScopeNestedCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
///////////////////////////////////////////////////////////////////////////////

#include "DxilConvPasses/ScopeNestedCFG.h"
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/Support/Global.h"
#include "llvm/Analysis/ReducibilityAnalysis.h"

Expand Down Expand Up @@ -161,6 +162,8 @@ class ScopeNestedCFG : public FunctionPass {
bool IsAcyclicRegionTerminator(const BasicBlock *pBB);

BasicBlock *GetEffectiveNodeToFollowSuccessor(BasicBlock *pBB);
bool IsSwitchCaseBlock(BasicBlock *BB);
bool IsSwitchFallthrough(BasicBlock *Pred, BasicBlock *BB);
bool IsMergePoint(BasicBlock *pBB);

BasicBlock *SplitEdge(BasicBlock *pBB, unsigned SuccIdx, const Twine &Name,
Expand Down Expand Up @@ -646,6 +649,71 @@ BasicBlock *ScopeNestedCFG::GetEffectiveNodeToFollowSuccessor(BasicBlock *pBB) {
return pEffectiveSuccessor;
}

// Returns true if BB is a direct target of at least one switch
// instruction, i.e. some predecessor of BB is terminated by a SwitchInst.
//
// Note that this does not distinguish case destinations from the switch's
// default destination: any successor of the switch qualifies.
bool ScopeNestedCFG::IsSwitchCaseBlock(BasicBlock *BB) {
  for (BasicBlock *Pred : predecessors(BB)) {
    // Pred is only reported as a predecessor because its terminator names BB
    // as a successor, so there is no need to re-scan the switch's successor
    // list looking for BB.
    if (isa<SwitchInst>(Pred->getTerminator()))
      return true;
  }
  return false;
}

// Decides whether the edge Pred -> BB is a switch fallthrough: one switch
// target block jumping unconditionally into another switch target block.
//
// Returns true only when all of the following hold:
//   - Pred is not itself terminated by a switch (it is not the dispatch
//     block),
//   - Pred ends in an unconditional branch whose target is BB,
//   - both Pred and BB satisfy IsSwitchCaseBlock.
bool ScopeNestedCFG::IsSwitchFallthrough(BasicBlock *Pred, BasicBlock *BB) {
  auto *pTerm = Pred->getTerminator();

  // The switch dispatch block reaches its cases directly; that is not a
  // fallthrough.
  if (isa<SwitchInst>(pTerm))
    return false;

  // The edge must be a plain unconditional jump from Pred straight to BB.
  auto *pJump = dyn_cast<BranchInst>(pTerm);
  const bool bJumpsStraightToBB =
      pJump && pJump->isUnconditional() && pJump->getSuccessor(0) == BB;
  if (!bJumpsStraightToBB)
    return false;

  // Both ends of the edge have to be targets of a switch.
  return IsSwitchCaseBlock(Pred) && IsSwitchCaseBlock(BB);
}

// Scans a basic block for a call that is sensitive to control-flow
// convergence and returns true as soon as one is found.
//
// An instruction qualifies when it is a call to a DXIL op function whose
// opcode is accepted by hlsl::OP::IsConvergentOp. In DXIL this covers:
//   - Wave operations (WaveActive*, WaveReadLane*, etc.)
//   - Quad / derivative operations (ddx, ddy, fwidth)
//
// Blocks containing such instructions must NEVER be cloned or executed
// along structurally duplicated control-flow paths.
bool HasConvergentCall(BasicBlock *BB) {
  for (Instruction &Inst : *BB) {
    // Only calls to DXIL op functions carry an opcode worth inspecting.
    if (!isa<CallInst>(&Inst) || !hlsl::OP::IsDxilOpFuncCallInst(&Inst))
      continue;

    if (hlsl::OP::IsConvergentOp(hlsl::OP::GetDxilOpFuncCallInst(&Inst)))
      return true;
  }

  return false;
}

bool ScopeNestedCFG::IsMergePoint(BasicBlock *pBB) {
unordered_set<BasicBlock *> UniquePredecessors;
for (auto itPred = pred_begin(pBB), endPred = pred_end(pBB);
Expand Down Expand Up @@ -1080,6 +1148,11 @@ void ScopeNestedCFG::DetermineScopeEndPoints(
pEndBB = BTO.GetBlock(MPI.MP);
}

if (HasConvergentCall(pBB)) {
// Force scope end point to be the block itself
pEndBB = pBB;
}

auto itOldEndPointBB = ScopeEndPoints.find(pBB);
if (itOldEndPointBB != ScopeEndPoints.end() &&
itOldEndPointBB->second != pEndBB) {
Expand Down Expand Up @@ -1255,6 +1328,13 @@ void ScopeNestedCFG::DetermineReachableMergePoints(
DXASSERT_NOMSG(m_LE2LBMap.find(pBB) == m_LE2LBMap.cend());
MPI.CandidateSet.insert(iBB);
}

if (HasConvergentCall(pBB)) {
// Force the merge point to be the block itself
MPI.MP = BTO.GetBlockId(pBB);
MPI.CandidateSet.clear();
MPI.CandidateSet.insert(MPI.MP);
}
}

// TODO: during final testing consider to remove.
Expand Down Expand Up @@ -1619,6 +1699,18 @@ void ScopeNestedCFG::TransformAcyclicRegion(BasicBlock *pEntry) {
//
BasicBlock *pSuccBB = pScopeBeginTI->getSuccessor(Scope.SuccIdx);

// Annotate all switch fallthrough branches
if (IsSwitchFallthrough(Scope.pScopeBeginBB, pSuccBB)) {
AnnotateBranch(Scope.pClonedScopeBeginBB, BranchKind::SwitchFallthrough);
}

// Only for convergent blocks, end the scope here
if (HasConvergentCall(Scope.pScopeBeginBB)) {
Scope.pScopeEndBB = Scope.pScopeBeginBB;
Scope.pClonedScopeEndBB = nullptr;
continue;
}

// 7. Already processed successor.
if (bIfScope || bSwitchScope) {
if (pSuccBB == Scope.pPrevSuccBB) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
; RUN: %opt-exe %s -scopenested -S | FileCheck %s

; verify that this pass won't identify fallthrough blocks
; as merge points and in turn won't duplicate blocks
; which contain convergent operations

; previously, blocks with wave operations would be cloned,
; violating the principle that wave operations should
; only be called by different threads when control flow
; is distinct between those threads

declare float @dx.op.waveActiveOp.f32(i32, float, i8, i8)

; CHECK-LABEL: define void @CSMain

; Test input: a two-case switch where case0 falls through into case1.
; Both case blocks contain a wave op call, so neither block may be
; duplicated by the pass. The FileCheck directives interleaved below pin
; the expected output block by block.
define void @CSMain(i32 %tid, float %v) convergent {
entry:
  ; Dispatch on %tid; any value other than 0 or 1 goes straight to %exit.
  switch i32 %tid, label %exit [
    i32 0, label %case0
    i32 1, label %case1
  ]

; CHECK: case0:
case0: ; switch case 0

; CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float %v, i8 0, i8 0)
; CHECK: br label %case1, !dx.BranchKind ![[BK:.*]]
  %w0 = call float @dx.op.waveActiveOp.f32(i32 119, float %v, i8 0, i8 0)
  br label %case1 ; FALLTHROUGH


; CHECK: case1:
case1: ; switch case 1
  ; Reached either directly from the switch or by falling through from case0.
  %a = phi float [ %w0, %case0 ],
                 [ 0.0, %entry ]

; CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float %v, i8 0, i8 0)
; CHECK: br label %exit, !dx.BranchKind ![[BK]]
  %w1 = call float @dx.op.waveActiveOp.f32(i32 119, float %v, i8 0, i8 0)
  %sum = fadd float %a, %w1
  br label %exit

; no cloning, so there should be no more waveops after this point
; (the pattern must spell the callee exactly as declared, including the
; ".op." component, otherwise it can never match a cloned call)
; CHECK-NOT: call float @dx.op.waveActiveOp.f32(i32 119

exit:
  %r = phi float [ %sum, %case1 ],
                 [ 0.0, %entry ]
  ret void
}

; The BranchKind::SwitchFallthrough Kind ID is 8
; CHECK: ![[BK]] = !{i32 8}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
; RUN: %opt-exe %s -scopenested -scopenestinfo -analyze -S | FileCheck %s

; verify that scope nest info looks correct for fallthrough cases
; we expect to see switch fallthrough treated as "Body", so
; it should not affect the scope indentation level.


; CHECK: @TopLevel_Begin
; CHECK: entry
; CHECK: @Switch_Begin
; CHECK: @Switch_Case
; CHECK: exit
; CHECK: @Switch_Break
; CHECK: @Switch_Case
; CHECK: case0
; CHECK: case1
; CHECK: exit
; CHECK: @Switch_Break
; CHECK: @Switch_Case
; CHECK: case1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I misreading this, or is case1 cloned here? Isn't that supposed to be prevented for blocks containing convergent ops? If this is a printing issue from scopenestinfo, shouldn't that pass be updated to properly express the CFG for fallthrough?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe this is a block that wasn't cloned, but is being iterated twice by scope nested iterator? I'm not familiar enough with these and haven't loaded enough context to be sure. Either way, CHECK: case1 seems insufficient to differentiate case1 from something like case1.1 which could be produced by cloning.

; CHECK: exit
; CHECK: @Switch_Break
; CHECK: @Switch_End
; CHECK: @TopLevel_End

declare float @dx.op.waveActiveOp.f32(i32, float, i8, i8)


; Test input: a two-case switch in which case0 falls through into case1.
; Both case blocks call the wave op, and the default destination of the
; switch is %exit.
define void @CSMain(i32 %tid, float %v) convergent {
entry:
; Dispatch on %tid: 0 -> case0, 1 -> case1, anything else -> exit.
switch i32 %tid, label %exit [
i32 0, label %case0
i32 1, label %case1
]

case0: ; switch case 0

%w0 = call float @dx.op.waveActiveOp.f32(i32 119, float %v, i8 0, i8 0)
br label %case1 ; FALLTHROUGH


case1: ; switch case 1
; Reached either directly from the switch (value 0.0) or by falling
; through from case0 (value %w0).
%a = phi float [ %w0, %case0 ],
[ 0.0, %entry ]

%w1 = call float @dx.op.waveActiveOp.f32(i32 119, float %v, i8 0, i8 0)
%sum = fadd float %a, %w1
br label %exit

exit:
; Merges the value computed in case1 with the default path from entry.
%r = phi float [ %sum, %case1 ],
[ 0.0, %entry ]
ret void
}