Shader_IR: Implement Fast BRX and allow multi-branches in the CFG.

This commit is contained in:
Fernando Sahmkow 2019-09-23 22:55:25 -04:00 committed by FernandoS27
parent acd6441134
commit 8909f52166
7 changed files with 260 additions and 132 deletions

View File

@ -2338,6 +2338,11 @@ public:
inner += expr.value ? "true" : "false"; inner += expr.value ? "true" : "false";
} }
void operator()(VideoCommon::Shader::ExprGprEqual& expr) {
inner +=
"( ftou(" + decomp.GetRegister(expr.gpr) + ") == " + std::to_string(expr.value) + ')';
}
const std::string& GetResult() const { const std::string& GetResult() const {
return inner; return inner;
} }

View File

@ -1704,6 +1704,13 @@ public:
return expr.value ? decomp.v_true : decomp.v_false; return expr.value ? decomp.v_true : decomp.v_false;
} }
Id operator()(const ExprGprEqual& expr) {
const Id target = decomp.Constant(decomp.t_uint, expr.value);
const Id gpr = decomp.BitcastTo<Type::Uint>(
decomp.Emit(decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr))));
return decomp.Emit(decomp.OpLogicalEqual(decomp.t_uint, gpr, target));
}
Id Visit(const Expr& node) { Id Visit(const Expr& node) {
return std::visit(*this, *node); return std::visit(*this, *node);
} }

View File

@ -228,6 +228,10 @@ public:
inner += expr.value ? "true" : "false"; inner += expr.value ? "true" : "false";
} }
void operator()(ExprGprEqual const& expr) {
inner += "( gpr_" + std::to_string(expr.gpr) + " == " + std::to_string(expr.value) + ')';
}
const std::string& GetResult() const { const std::string& GetResult() const {
return inner; return inner;
} }

View File

@ -35,14 +35,24 @@ struct BlockStack {
std::stack<u32> pbk_stack{}; std::stack<u32> pbk_stack{};
}; };
struct BlockBranchInfo { template <typename T, typename... Args>
Condition condition{}; BlockBranchInfo MakeBranchInfo(Args&&... args) {
s32 address{exit_branch}; static_assert(std::is_convertible_v<T, BranchData>);
bool kill{}; return std::make_shared<BranchData>(T(std::forward<Args>(args)...));
bool is_sync{}; }
bool is_brk{};
bool ignore{}; bool BlockBranchInfoAreEqual(BlockBranchInfo first, BlockBranchInfo second) {
}; return false; //(*first) == (*second);
}
bool BlockBranchIsIgnored(BlockBranchInfo first) {
bool ignore = false;
if (std::holds_alternative<SingleBranch>(*first)) {
auto branch = std::get_if<SingleBranch>(first.get());
ignore = branch->ignore;
}
return ignore;
}
struct BlockInfo { struct BlockInfo {
u32 start{}; u32 start{};
@ -234,6 +244,7 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
u32 offset = static_cast<u32>(address); u32 offset = static_cast<u32>(address);
const u32 end_address = static_cast<u32>(state.program_size / sizeof(Instruction)); const u32 end_address = static_cast<u32>(state.program_size / sizeof(Instruction));
ParseInfo parse_info{}; ParseInfo parse_info{};
SingleBranch single_branch{};
const auto insert_label = [](CFGRebuildState& state, u32 address) { const auto insert_label = [](CFGRebuildState& state, u32 address) {
const auto pair = state.labels.emplace(address); const auto pair = state.labels.emplace(address);
@ -246,13 +257,14 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
if (offset >= end_address) { if (offset >= end_address) {
// ASSERT_OR_EXECUTE can't be used, as it ignores the break // ASSERT_OR_EXECUTE can't be used, as it ignores the break
ASSERT_MSG(false, "Shader passed the current limit!"); ASSERT_MSG(false, "Shader passed the current limit!");
parse_info.branch_info.address = exit_branch;
parse_info.branch_info.ignore = false; single_branch.address = exit_branch;
single_branch.ignore = false;
break; break;
} }
if (state.registered.count(offset) != 0) { if (state.registered.count(offset) != 0) {
parse_info.branch_info.address = offset; single_branch.address = offset;
parse_info.branch_info.ignore = true; single_branch.ignore = true;
break; break;
} }
if (IsSchedInstruction(offset, state.start)) { if (IsSchedInstruction(offset, state.start)) {
@ -269,24 +281,26 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
switch (opcode->get().GetId()) { switch (opcode->get().GetId()) {
case OpCode::Id::EXIT: { case OpCode::Id::EXIT: {
const auto pred_index = static_cast<u32>(instr.pred.pred_index); const auto pred_index = static_cast<u32>(instr.pred.pred_index);
parse_info.branch_info.condition.predicate = single_branch.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0);
GetPredicate(pred_index, instr.negate_pred != 0); if (single_branch.condition.predicate == Pred::NeverExecute) {
if (parse_info.branch_info.condition.predicate == Pred::NeverExecute) {
offset++; offset++;
continue; continue;
} }
const ConditionCode cc = instr.flow_condition_code; const ConditionCode cc = instr.flow_condition_code;
parse_info.branch_info.condition.cc = cc; single_branch.condition.cc = cc;
if (cc == ConditionCode::F) { if (cc == ConditionCode::F) {
offset++; offset++;
continue; continue;
} }
parse_info.branch_info.address = exit_branch; single_branch.address = exit_branch;
parse_info.branch_info.kill = false; single_branch.kill = false;
parse_info.branch_info.is_sync = false; single_branch.is_sync = false;
parse_info.branch_info.is_brk = false; single_branch.is_brk = false;
parse_info.branch_info.ignore = false; single_branch.ignore = false;
parse_info.end_address = offset; parse_info.end_address = offset;
parse_info.branch_info = MakeBranchInfo<SingleBranch>(
single_branch.condition, single_branch.address, single_branch.kill,
single_branch.is_sync, single_branch.is_brk, single_branch.ignore);
return {ParseResult::ControlCaught, parse_info}; return {ParseResult::ControlCaught, parse_info};
} }
@ -295,99 +309,107 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
return {ParseResult::AbnormalFlow, parse_info}; return {ParseResult::AbnormalFlow, parse_info};
} }
const auto pred_index = static_cast<u32>(instr.pred.pred_index); const auto pred_index = static_cast<u32>(instr.pred.pred_index);
parse_info.branch_info.condition.predicate = single_branch.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0);
GetPredicate(pred_index, instr.negate_pred != 0); if (single_branch.condition.predicate == Pred::NeverExecute) {
if (parse_info.branch_info.condition.predicate == Pred::NeverExecute) {
offset++; offset++;
continue; continue;
} }
const ConditionCode cc = instr.flow_condition_code; const ConditionCode cc = instr.flow_condition_code;
parse_info.branch_info.condition.cc = cc; single_branch.condition.cc = cc;
if (cc == ConditionCode::F) { if (cc == ConditionCode::F) {
offset++; offset++;
continue; continue;
} }
const u32 branch_offset = offset + instr.bra.GetBranchTarget(); const u32 branch_offset = offset + instr.bra.GetBranchTarget();
if (branch_offset == 0) { if (branch_offset == 0) {
parse_info.branch_info.address = exit_branch; single_branch.address = exit_branch;
} else { } else {
parse_info.branch_info.address = branch_offset; single_branch.address = branch_offset;
} }
insert_label(state, branch_offset); insert_label(state, branch_offset);
parse_info.branch_info.kill = false; single_branch.kill = false;
parse_info.branch_info.is_sync = false; single_branch.is_sync = false;
parse_info.branch_info.is_brk = false; single_branch.is_brk = false;
parse_info.branch_info.ignore = false; single_branch.ignore = false;
parse_info.end_address = offset; parse_info.end_address = offset;
parse_info.branch_info = MakeBranchInfo<SingleBranch>(
single_branch.condition, single_branch.address, single_branch.kill,
single_branch.is_sync, single_branch.is_brk, single_branch.ignore);
return {ParseResult::ControlCaught, parse_info}; return {ParseResult::ControlCaught, parse_info};
} }
case OpCode::Id::SYNC: { case OpCode::Id::SYNC: {
const auto pred_index = static_cast<u32>(instr.pred.pred_index); const auto pred_index = static_cast<u32>(instr.pred.pred_index);
parse_info.branch_info.condition.predicate = single_branch.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0);
GetPredicate(pred_index, instr.negate_pred != 0); if (single_branch.condition.predicate == Pred::NeverExecute) {
if (parse_info.branch_info.condition.predicate == Pred::NeverExecute) {
offset++; offset++;
continue; continue;
} }
const ConditionCode cc = instr.flow_condition_code; const ConditionCode cc = instr.flow_condition_code;
parse_info.branch_info.condition.cc = cc; single_branch.condition.cc = cc;
if (cc == ConditionCode::F) { if (cc == ConditionCode::F) {
offset++; offset++;
continue; continue;
} }
parse_info.branch_info.address = unassigned_branch; single_branch.address = unassigned_branch;
parse_info.branch_info.kill = false; single_branch.kill = false;
parse_info.branch_info.is_sync = true; single_branch.is_sync = true;
parse_info.branch_info.is_brk = false; single_branch.is_brk = false;
parse_info.branch_info.ignore = false; single_branch.ignore = false;
parse_info.end_address = offset; parse_info.end_address = offset;
parse_info.branch_info = MakeBranchInfo<SingleBranch>(
single_branch.condition, single_branch.address, single_branch.kill,
single_branch.is_sync, single_branch.is_brk, single_branch.ignore);
return {ParseResult::ControlCaught, parse_info}; return {ParseResult::ControlCaught, parse_info};
} }
case OpCode::Id::BRK: { case OpCode::Id::BRK: {
const auto pred_index = static_cast<u32>(instr.pred.pred_index); const auto pred_index = static_cast<u32>(instr.pred.pred_index);
parse_info.branch_info.condition.predicate = single_branch.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0);
GetPredicate(pred_index, instr.negate_pred != 0); if (single_branch.condition.predicate == Pred::NeverExecute) {
if (parse_info.branch_info.condition.predicate == Pred::NeverExecute) {
offset++; offset++;
continue; continue;
} }
const ConditionCode cc = instr.flow_condition_code; const ConditionCode cc = instr.flow_condition_code;
parse_info.branch_info.condition.cc = cc; single_branch.condition.cc = cc;
if (cc == ConditionCode::F) { if (cc == ConditionCode::F) {
offset++; offset++;
continue; continue;
} }
parse_info.branch_info.address = unassigned_branch; single_branch.address = unassigned_branch;
parse_info.branch_info.kill = false; single_branch.kill = false;
parse_info.branch_info.is_sync = false; single_branch.is_sync = false;
parse_info.branch_info.is_brk = true; single_branch.is_brk = true;
parse_info.branch_info.ignore = false; single_branch.ignore = false;
parse_info.end_address = offset; parse_info.end_address = offset;
parse_info.branch_info = MakeBranchInfo<SingleBranch>(
single_branch.condition, single_branch.address, single_branch.kill,
single_branch.is_sync, single_branch.is_brk, single_branch.ignore);
return {ParseResult::ControlCaught, parse_info}; return {ParseResult::ControlCaught, parse_info};
} }
case OpCode::Id::KIL: { case OpCode::Id::KIL: {
const auto pred_index = static_cast<u32>(instr.pred.pred_index); const auto pred_index = static_cast<u32>(instr.pred.pred_index);
parse_info.branch_info.condition.predicate = single_branch.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0);
GetPredicate(pred_index, instr.negate_pred != 0); if (single_branch.condition.predicate == Pred::NeverExecute) {
if (parse_info.branch_info.condition.predicate == Pred::NeverExecute) {
offset++; offset++;
continue; continue;
} }
const ConditionCode cc = instr.flow_condition_code; const ConditionCode cc = instr.flow_condition_code;
parse_info.branch_info.condition.cc = cc; single_branch.condition.cc = cc;
if (cc == ConditionCode::F) { if (cc == ConditionCode::F) {
offset++; offset++;
continue; continue;
} }
parse_info.branch_info.address = exit_branch; single_branch.address = exit_branch;
parse_info.branch_info.kill = true; single_branch.kill = true;
parse_info.branch_info.is_sync = false; single_branch.is_sync = false;
parse_info.branch_info.is_brk = false; single_branch.is_brk = false;
parse_info.branch_info.ignore = false; single_branch.ignore = false;
parse_info.end_address = offset; parse_info.end_address = offset;
parse_info.branch_info = MakeBranchInfo<SingleBranch>(
single_branch.condition, single_branch.address, single_branch.kill,
single_branch.is_sync, single_branch.is_brk, single_branch.ignore);
return {ParseResult::ControlCaught, parse_info}; return {ParseResult::ControlCaught, parse_info};
} }
@ -407,16 +429,25 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
auto tmp = TrackBranchIndirectInfo(state, address, offset); auto tmp = TrackBranchIndirectInfo(state, address, offset);
if (tmp) { if (tmp) {
auto result = *tmp; auto result = *tmp;
std::string entries{}; std::vector<CaseBranch> branches{};
s32 pc_target = offset + result.relative_position;
for (u32 i = 0; i < result.entries; i++) { for (u32 i = 0; i < result.entries; i++) {
auto k = locker.ObtainKey(result.buffer, result.offset + i * 4); auto k = state.locker.ObtainKey(result.buffer, result.offset + i * 4);
entries = entries + std::to_string(*k) + '\n'; if (!k) {
return {ParseResult::AbnormalFlow, parse_info};
}
u32 value = *k;
u32 target = static_cast<u32>((value >> 3) + pc_target);
insert_label(state, target);
branches.emplace_back(value, target);
} }
LOG_CRITICAL(HW_GPU, parse_info.end_address = offset;
"Track Successful, BRX: buffer:{}, offset:{}, entries:{}, inner:\n{}", parse_info.branch_info =
result.buffer, result.offset, result.entries, entries); MakeBranchInfo<MultiBranch>(static_cast<u32>(instr.gpr8.Value()), branches);
return {ParseResult::ControlCaught, parse_info};
} else { } else {
LOG_CRITICAL(HW_GPU, "Track Unsuccesful"); LOG_WARNING(HW_GPU, "BRX Track Unsuccesful");
} }
return {ParseResult::AbnormalFlow, parse_info}; return {ParseResult::AbnormalFlow, parse_info};
} }
@ -426,10 +457,13 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
offset++; offset++;
} }
parse_info.branch_info.kill = false; single_branch.kill = false;
parse_info.branch_info.is_sync = false; single_branch.is_sync = false;
parse_info.branch_info.is_brk = false; single_branch.is_brk = false;
parse_info.end_address = offset - 1; parse_info.end_address = offset - 1;
parse_info.branch_info = MakeBranchInfo<SingleBranch>(
single_branch.condition, single_branch.address, single_branch.kill, single_branch.is_sync,
single_branch.is_brk, single_branch.ignore);
return {ParseResult::BlockEnd, parse_info}; return {ParseResult::BlockEnd, parse_info};
} }
@ -453,9 +487,10 @@ bool TryInspectAddress(CFGRebuildState& state) {
BlockInfo& current_block = state.block_info[block_index]; BlockInfo& current_block = state.block_info[block_index];
current_block.end = address - 1; current_block.end = address - 1;
new_block.branch = current_block.branch; new_block.branch = current_block.branch;
BlockBranchInfo forward_branch{}; BlockBranchInfo forward_branch = MakeBranchInfo<SingleBranch>();
forward_branch.address = address; auto branch = std::get_if<SingleBranch>(forward_branch.get());
forward_branch.ignore = true; branch->address = address;
branch->ignore = true;
current_block.branch = forward_branch; current_block.branch = forward_branch;
return true; return true;
} }
@ -470,12 +505,15 @@ bool TryInspectAddress(CFGRebuildState& state) {
BlockInfo& block_info = CreateBlockInfo(state, address, parse_info.end_address); BlockInfo& block_info = CreateBlockInfo(state, address, parse_info.end_address);
block_info.branch = parse_info.branch_info; block_info.branch = parse_info.branch_info;
if (parse_info.branch_info.condition.IsUnconditional()) { if (std::holds_alternative<SingleBranch>(*block_info.branch)) {
auto branch = std::get_if<SingleBranch>(block_info.branch.get());
if (branch->condition.IsUnconditional()) {
return true;
}
const u32 fallthrough_address = parse_info.end_address + 1;
state.inspect_queries.push_front(fallthrough_address);
return true; return true;
} }
const u32 fallthrough_address = parse_info.end_address + 1;
state.inspect_queries.push_front(fallthrough_address);
return true; return true;
} }
@ -513,31 +551,41 @@ bool TryQuery(CFGRebuildState& state) {
state.queries.pop_front(); state.queries.pop_front();
gather_labels(q2.ssy_stack, state.ssy_labels, block); gather_labels(q2.ssy_stack, state.ssy_labels, block);
gather_labels(q2.pbk_stack, state.pbk_labels, block); gather_labels(q2.pbk_stack, state.pbk_labels, block);
if (!block.branch.condition.IsUnconditional()) { if (std::holds_alternative<SingleBranch>(*block.branch)) {
q2.address = block.end + 1; auto branch = std::get_if<SingleBranch>(block.branch.get());
state.queries.push_back(q2); if (!branch->condition.IsUnconditional()) {
} q2.address = block.end + 1;
state.queries.push_back(q2);
}
Query conditional_query{q2}; Query conditional_query{q2};
if (block.branch.is_sync) { if (branch->is_sync) {
if (block.branch.address == unassigned_branch) { if (branch->address == unassigned_branch) {
block.branch.address = conditional_query.ssy_stack.top(); branch->address = conditional_query.ssy_stack.top();
}
conditional_query.ssy_stack.pop();
} }
conditional_query.ssy_stack.pop(); if (branch->is_brk) {
} if (branch->address == unassigned_branch) {
if (block.branch.is_brk) { branch->address = conditional_query.pbk_stack.top();
if (block.branch.address == unassigned_branch) { }
block.branch.address = conditional_query.pbk_stack.top(); conditional_query.pbk_stack.pop();
} }
conditional_query.pbk_stack.pop(); conditional_query.address = branch->address;
state.queries.push_back(std::move(conditional_query));
return true;
}
auto multi_branch = std::get_if<MultiBranch>(block.branch.get());
for (auto& branch_case : multi_branch->branches) {
Query conditional_query{q2};
conditional_query.address = branch_case.address;
state.queries.push_back(std::move(conditional_query));
} }
conditional_query.address = block.branch.address;
state.queries.push_back(std::move(conditional_query));
return true; return true;
} }
} // Anonymous namespace } // Anonymous namespace
void InsertBranch(ASTManager& mm, const BlockBranchInfo& branch) { void InsertBranch(ASTManager& mm, const BlockBranchInfo& branch_info) {
const auto get_expr = ([&](const Condition& cond) -> Expr { const auto get_expr = ([&](const Condition& cond) -> Expr {
Expr result{}; Expr result{};
if (cond.cc != ConditionCode::T) { if (cond.cc != ConditionCode::T) {
@ -564,15 +612,24 @@ void InsertBranch(ASTManager& mm, const BlockBranchInfo& branch) {
} }
return MakeExpr<ExprBoolean>(true); return MakeExpr<ExprBoolean>(true);
}); });
if (branch.address < 0) { if (std::holds_alternative<SingleBranch>(*branch_info)) {
if (branch.kill) { auto branch = std::get_if<SingleBranch>(branch_info.get());
mm.InsertReturn(get_expr(branch.condition), true); if (branch->address < 0) {
if (branch->kill) {
mm.InsertReturn(get_expr(branch->condition), true);
return;
}
mm.InsertReturn(get_expr(branch->condition), false);
return; return;
} }
mm.InsertReturn(get_expr(branch.condition), false); mm.InsertGoto(get_expr(branch->condition), branch->address);
return; return;
} }
mm.InsertGoto(get_expr(branch.condition), branch.address); auto multi_branch = std::get_if<MultiBranch>(branch_info.get());
for (auto& branch_case : multi_branch->branches) {
mm.InsertGoto(MakeExpr<ExprGprEqual>(multi_branch->gpr, branch_case.cmp_value),
branch_case.address);
}
} }
void DecompileShader(CFGRebuildState& state) { void DecompileShader(CFGRebuildState& state) {
@ -584,9 +641,10 @@ void DecompileShader(CFGRebuildState& state) {
if (state.labels.count(block.start) != 0) { if (state.labels.count(block.start) != 0) {
state.manager->InsertLabel(block.start); state.manager->InsertLabel(block.start);
} }
u32 end = block.branch.ignore ? block.end + 1 : block.end; const bool ignore = BlockBranchIsIgnored(block.branch);
u32 end = ignore ? block.end + 1 : block.end;
state.manager->InsertBlock(block.start, end); state.manager->InsertBlock(block.start, end);
if (!block.branch.ignore) { if (!ignore) {
InsertBranch(*state.manager, block.branch); InsertBranch(*state.manager, block.branch);
} }
} }
@ -668,11 +726,9 @@ std::unique_ptr<ShaderCharacteristics> ScanFlow(const ProgramCode& program_code,
ShaderBlock new_block{}; ShaderBlock new_block{};
new_block.start = block.start; new_block.start = block.start;
new_block.end = block.end; new_block.end = block.end;
new_block.ignore_branch = block.branch.ignore; new_block.ignore_branch = BlockBranchIsIgnored(block.branch);
if (!new_block.ignore_branch) { if (!new_block.ignore_branch) {
new_block.branch.cond = block.branch.condition; new_block.branch = block.branch;
new_block.branch.kills = block.branch.kill;
new_block.branch.address = block.branch.address;
} }
result_out->end = std::max(result_out->end, block.end); result_out->end = std::max(result_out->end, block.end);
result_out->blocks.push_back(new_block); result_out->blocks.push_back(new_block);

View File

@ -7,6 +7,7 @@
#include <list> #include <list>
#include <optional> #include <optional>
#include <set> #include <set>
#include <variant>
#include "video_core/engines/shader_bytecode.h" #include "video_core/engines/shader_bytecode.h"
#include "video_core/shader/ast.h" #include "video_core/shader/ast.h"
@ -37,29 +38,57 @@ struct Condition {
} }
}; };
class SingleBranch {
public:
SingleBranch() = default;
SingleBranch(Condition condition, s32 address, bool kill, bool is_sync, bool is_brk,
bool ignore)
: condition{condition}, address{address}, kill{kill}, is_sync{is_sync}, is_brk{is_brk},
ignore{ignore} {}
bool operator==(const SingleBranch& b) const {
return std::tie(condition, address, kill, is_sync, is_brk, ignore) ==
std::tie(b.condition, b.address, b.kill, b.is_sync, b.is_brk, b.ignore);
}
Condition condition{};
s32 address{exit_branch};
bool kill{};
bool is_sync{};
bool is_brk{};
bool ignore{};
};
struct CaseBranch {
CaseBranch(u32 cmp_value, u32 address) : cmp_value{cmp_value}, address{address} {}
u32 cmp_value;
u32 address;
};
class MultiBranch {
public:
MultiBranch(u32 gpr, std::vector<CaseBranch>& branches)
: gpr{gpr}, branches{std::move(branches)} {}
u32 gpr{};
std::vector<CaseBranch> branches{};
};
using BranchData = std::variant<SingleBranch, MultiBranch>;
using BlockBranchInfo = std::shared_ptr<BranchData>;
bool BlockBranchInfoAreEqual(BlockBranchInfo first, BlockBranchInfo second);
struct ShaderBlock { struct ShaderBlock {
struct Branch {
Condition cond{};
bool kills{};
s32 address{};
bool operator==(const Branch& b) const {
return std::tie(cond, kills, address) == std::tie(b.cond, b.kills, b.address);
}
bool operator!=(const Branch& b) const {
return !operator==(b);
}
};
u32 start{}; u32 start{};
u32 end{}; u32 end{};
bool ignore_branch{}; bool ignore_branch{};
Branch branch{}; BlockBranchInfo branch{};
bool operator==(const ShaderBlock& sb) const { bool operator==(const ShaderBlock& sb) const {
return std::tie(start, end, ignore_branch, branch) == return std::tie(start, end, ignore_branch) ==
std::tie(sb.start, sb.end, sb.ignore_branch, sb.branch); std::tie(sb.start, sb.end, sb.ignore_branch) &&
BlockBranchInfoAreEqual(branch, sb.branch);
} }
bool operator!=(const ShaderBlock& sb) const { bool operator!=(const ShaderBlock& sb) const {

View File

@ -198,24 +198,38 @@ void ShaderIR::InsertControlFlow(NodeBlock& bb, const ShaderBlock& block) {
} }
return result; return result;
}; };
if (block.branch.address < 0) { if (std::holds_alternative<SingleBranch>(*block.branch)) {
if (block.branch.kills) { auto branch = std::get_if<SingleBranch>(block.branch.get());
Node n = Operation(OperationCode::Discard); if (branch->address < 0) {
n = apply_conditions(block.branch.cond, n); if (branch->kill) {
Node n = Operation(OperationCode::Discard);
n = apply_conditions(branch->condition, n);
bb.push_back(n);
global_code.push_back(n);
return;
}
Node n = Operation(OperationCode::Exit);
n = apply_conditions(branch->condition, n);
bb.push_back(n); bb.push_back(n);
global_code.push_back(n); global_code.push_back(n);
return; return;
} }
Node n = Operation(OperationCode::Exit); Node n = Operation(OperationCode::Branch, Immediate(branch->address));
n = apply_conditions(block.branch.cond, n); n = apply_conditions(branch->condition, n);
bb.push_back(n); bb.push_back(n);
global_code.push_back(n); global_code.push_back(n);
return; return;
} }
Node n = Operation(OperationCode::Branch, Immediate(block.branch.address)); auto multi_branch = std::get_if<MultiBranch>(block.branch.get());
n = apply_conditions(block.branch.cond, n); Node op_a = GetRegister(multi_branch->gpr);
bb.push_back(n); for (auto& branch_case : multi_branch->branches) {
global_code.push_back(n); Node n = Operation(OperationCode::Branch, Immediate(branch_case.address));
Node op_b = Immediate(branch_case.cmp_value);
Node condition = GetPredicateComparisonInteger(Tegra::Shader::PredCondition::Equal, false, op_a, op_b);
auto result = Conditional(condition, {n});
bb.push_back(result);
global_code.push_back(result);
}
} }
u32 ShaderIR::DecodeInstr(NodeBlock& bb, u32 pc) { u32 ShaderIR::DecodeInstr(NodeBlock& bb, u32 pc) {

View File

@ -17,13 +17,14 @@ using Tegra::Shader::Pred;
class ExprAnd; class ExprAnd;
class ExprBoolean; class ExprBoolean;
class ExprCondCode; class ExprCondCode;
class ExprGprEqual;
class ExprNot; class ExprNot;
class ExprOr; class ExprOr;
class ExprPredicate; class ExprPredicate;
class ExprVar; class ExprVar;
using ExprData = using ExprData = std::variant<ExprVar, ExprCondCode, ExprPredicate, ExprNot, ExprOr, ExprAnd,
std::variant<ExprVar, ExprCondCode, ExprPredicate, ExprNot, ExprOr, ExprAnd, ExprBoolean>; ExprBoolean, ExprGprEqual>;
using Expr = std::shared_ptr<ExprData>; using Expr = std::shared_ptr<ExprData>;
class ExprAnd final { class ExprAnd final {
@ -118,6 +119,18 @@ public:
bool value; bool value;
}; };
class ExprGprEqual final {
public:
ExprGprEqual(u32 gpr, u32 value) : gpr{gpr}, value{value} {}
bool operator==(const ExprGprEqual& b) const {
return gpr == b.gpr && value == b.value;
}
u32 gpr;
u32 value;
};
template <typename T, typename... Args> template <typename T, typename... Args>
Expr MakeExpr(Args&&... args) { Expr MakeExpr(Args&&... args) {
static_assert(std::is_convertible_v<T, ExprData>); static_assert(std::is_convertible_v<T, ExprData>);