Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions include/spirv-tools/optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class Pass;
struct DescriptorSetAndBinding;
} // namespace opt

enum class SSARewriteMode {
None,
All,
OpaqueOnly,
SpecialTypes,
};

// C++ interface for SPIR-V optimization functionalities. It wraps the context
// (including target environment and the corresponding SPIR-V grammar) and
// provides methods for registering optimization passes and optimizing.
Expand Down Expand Up @@ -125,6 +132,9 @@ class SPIRV_TOOLS_EXPORT Optimizer {
// interface are considered live and are not eliminated.
Optimizer& RegisterLegalizationPasses();
Optimizer& RegisterLegalizationPasses(bool preserve_interface);
Optimizer& RegisterLegalizationPasses(bool preserve_interface,
bool include_loop_unroll,
SSARewriteMode ssa_rewrite_mode);

// Register passes specified in the list of |flags|. Each flag must be a
// string of a form accepted by Optimizer::FlagHasValidForm().
Expand Down Expand Up @@ -645,11 +655,6 @@ Optimizer::PassToken CreateLoopPeelingPass();
// Works best after LICM and local multi store elimination pass.
Optimizer::PassToken CreateLoopUnswitchPass();

// Creates a pass to legalize multidimensional arrays for Vulkan.
// This pass will replace multidimensional arrays of resources with a single
// dimensional array. Combine-access-chains should be run before this pass.
Optimizer::PassToken CreateLegalizeMultidimArrayPass();

// Create global value numbering pass.
// This pass will look for instructions where the same value is computed on all
// paths leading to the instruction. Those instructions are deleted.
Expand Down Expand Up @@ -709,7 +714,8 @@ Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor = 0);
// operations on SSA IDs. This allows SSA optimizers to act on these variables.
// Only variables that are local to the function and of supported types are
// processed (see IsSSATargetVar for details).
Optimizer::PassToken CreateSSARewritePass();
Optimizer::PassToken CreateSSARewritePass(
SSARewriteMode mode = SSARewriteMode::All);

// Create pass to convert relaxed precision instructions to half precision.
// This pass converts as many relaxed float32 arithmetic operations to half as
Expand Down
23 changes: 23 additions & 0 deletions source/opt/local_single_store_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,18 @@ bool LocalSingleStoreElimPass::RewriteLoads(
else
stored_id = store_inst->GetSingleWordInOperand(kVariableInitIdInIdx);

const auto get_image_pointer_id = [this](uint32_t value_id) {
Instruction* value_inst = context()->get_def_use_mgr()->GetDef(value_id);
while (value_inst && value_inst->opcode() == spv::Op::OpCopyObject) {
value_id = value_inst->GetSingleWordInOperand(0);
value_inst = context()->get_def_use_mgr()->GetDef(value_id);
}
if (!value_inst || value_inst->opcode() != spv::Op::OpLoad) {
return uint32_t{0};
}
return value_inst->GetSingleWordInOperand(0);
};

*all_rewritten = true;
bool modified = false;
for (Instruction* use : uses) {
Expand All @@ -319,6 +331,17 @@ bool LocalSingleStoreElimPass::RewriteLoads(
context()->KillNamesAndDecorates(use->result_id());
context()->ReplaceAllUsesWith(use->result_id(), stored_id);
context()->KillInst(use);
} else if (use->opcode() == spv::Op::OpImageTexelPointer &&
dominator_analysis->Dominates(store_inst, use)) {
const uint32_t image_ptr_id = get_image_pointer_id(stored_id);
if (image_ptr_id == 0) {
*all_rewritten = false;
continue;
}
modified = true;
context()->ForgetUses(use);
use->SetInOperand(0, {image_ptr_id});
context()->AnalyzeUses(use);
} else {
*all_rewritten = false;
}
Expand Down
40 changes: 27 additions & 13 deletions source/opt/mem_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,27 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
}

bool MemPass::IsTargetType(const Instruction* typeInst) const {
if (IsBaseTargetType(typeInst)) return true;
switch (ssa_rewrite_mode_) {
case SSARewriteMode::None:
return false;
case SSARewriteMode::OpaqueOnly:
if (typeInst->IsOpaqueType()) return true;
break;
case SSARewriteMode::SpecialTypes:
if (typeInst->IsOpaqueType()) return true;
switch (typeInst->opcode()) {
case spv::Op::OpTypePointer:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
return true;
default:
break;
}
break;
case SSARewriteMode::All:
if (IsBaseTargetType(typeInst)) return true;
break;
}
if (typeInst->opcode() == spv::Op::OpTypeArray) {
if (!IsTargetType(
get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) {
Expand All @@ -72,8 +92,7 @@ bool MemPass::IsTargetType(const Instruction* typeInst) const {

bool MemPass::IsNonPtrAccessChain(const spv::Op opcode) const {
return opcode == spv::Op::OpAccessChain ||
opcode == spv::Op::OpInBoundsAccessChain ||
opcode == spv::Op::OpUntypedAccessChainKHR;
opcode == spv::Op::OpInBoundsAccessChain;
}

bool MemPass::IsPtr(uint32_t ptrId) {
Expand All @@ -89,14 +108,11 @@ bool MemPass::IsPtr(uint32_t ptrId) {
ptrInst = get_def_use_mgr()->GetDef(varId);
}
const spv::Op op = ptrInst->opcode();
if (op == spv::Op::OpVariable || op == spv::Op::OpUntypedVariableKHR ||
IsNonPtrAccessChain(op))
return true;
if (op == spv::Op::OpVariable || IsNonPtrAccessChain(op)) return true;
const uint32_t varTypeId = ptrInst->type_id();
if (varTypeId == 0) return false;
const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
return varTypeInst->opcode() == spv::Op::OpTypePointer ||
varTypeInst->opcode() == spv::Op::OpTypeUntypedPointerKHR;
return varTypeInst->opcode() == spv::Op::OpTypePointer;
}

Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
Expand All @@ -106,13 +122,11 @@ Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {

switch (ptrInst->opcode()) {
case spv::Op::OpVariable:
case spv::Op::OpUntypedVariableKHR:
case spv::Op::OpFunctionParameter:
varInst = ptrInst;
break;
case spv::Op::OpAccessChain:
case spv::Op::OpInBoundsAccessChain:
case spv::Op::OpUntypedAccessChainKHR:
case spv::Op::OpPtrAccessChain:
case spv::Op::OpInBoundsPtrAccessChain:
case spv::Op::OpImageTexelPointer:
Expand All @@ -125,8 +139,7 @@ Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
break;
}

if (varInst->opcode() == spv::Op::OpVariable ||
varInst->opcode() == spv::Op::OpUntypedVariableKHR) {
if (varInst->opcode() == spv::Op::OpVariable) {
*varId = varInst->result_id();
} else {
*varId = 0;
Expand Down Expand Up @@ -241,7 +254,8 @@ void MemPass::DCEInst(Instruction* inst,
}
}

MemPass::MemPass() {}
MemPass::MemPass(SSARewriteMode ssa_rewrite_mode)
: ssa_rewrite_mode_(ssa_rewrite_mode) {}

bool MemPass::HasOnlySupportedRefs(uint32_t varId) {
return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {
Expand Down
7 changes: 5 additions & 2 deletions source/opt/mem_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <unordered_set>
#include <utility>

#include "spirv-tools/optimizer.hpp"
#include "source/opt/basic_block.h"
#include "source/opt/def_use_manager.h"
#include "source/opt/dominator_analysis.h"
Expand Down Expand Up @@ -68,7 +69,7 @@ class MemPass : public Pass {
void CollectTargetVars(Function* func);

protected:
MemPass();
explicit MemPass(SSARewriteMode ssa_rewrite_mode = SSARewriteMode::All);

// Returns true if |typeInst| is a scalar type
// or a vector or matrix
Expand Down Expand Up @@ -133,7 +134,9 @@ class MemPass : public Pass {
// Cache of verified non-target vars
std::unordered_set<uint32_t> seen_non_target_vars_;

private:
private:
SSARewriteMode ssa_rewrite_mode_ = SSARewriteMode::All;

// Return true if all uses of |varId| are only through supported reference
// operations ie. loads and store. Also cache in supported_ref_vars_.
// TODO(dnovillo): This function is replicated in other passes and it's
Expand Down
Loading
Loading