LLVM API Documentation
00001 //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===// 00002 // 00003 // The LLVM Compiler Infrastructure 00004 // 00005 // This file is distributed under the University of Illinois Open Source 00006 // License. See LICENSE.TXT for details. 00007 // 00008 //===----------------------------------------------------------------------===// 00009 // 00010 // The LowerSwitch transformation rewrites switch instructions with a sequence 00011 // of branches, which allows targets to get away with not implementing the 00012 // switch instruction until it is convenient. 00013 // 00014 //===----------------------------------------------------------------------===// 00015 00016 #include "llvm/Transforms/Scalar.h" 00017 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 00018 #include "llvm/ADT/STLExtras.h" 00019 #include "llvm/IR/Constants.h" 00020 #include "llvm/IR/Function.h" 00021 #include "llvm/IR/Instructions.h" 00022 #include "llvm/IR/LLVMContext.h" 00023 #include "llvm/IR/CFG.h" 00024 #include "llvm/Pass.h" 00025 #include "llvm/Support/Compiler.h" 00026 #include "llvm/Support/Debug.h" 00027 #include "llvm/Support/raw_ostream.h" 00028 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" 00029 #include <algorithm> 00030 using namespace llvm; 00031 00032 #define DEBUG_TYPE "lower-switch" 00033 00034 namespace { 00035 /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch 00036 /// instructions. 00037 class LowerSwitch : public FunctionPass { 00038 public: 00039 static char ID; // Pass identification, replacement for typeid 00040 LowerSwitch() : FunctionPass(ID) { 00041 initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); 00042 } 00043 00044 bool runOnFunction(Function &F) override; 00045 00046 void getAnalysisUsage(AnalysisUsage &AU) const override { 00047 // This is a cluster of orthogonal Transforms 00048 AU.addPreserved<UnifyFunctionExitNodes>(); 00049 AU.addPreserved("mem2reg"); 00050 AU.addPreservedID(LowerInvokePassID); 00051 } 00052 00053 struct CaseRange { 00054 Constant* Low; 00055 Constant* High; 00056 BasicBlock* BB; 00057 00058 CaseRange(Constant *low = nullptr, Constant *high = nullptr, 00059 BasicBlock *bb = nullptr) : 00060 Low(low), High(high), BB(bb) { } 00061 }; 00062 00063 typedef std::vector<CaseRange> CaseVector; 00064 typedef std::vector<CaseRange>::iterator CaseItr; 00065 private: 00066 void processSwitchInst(SwitchInst *SI); 00067 00068 BasicBlock *switchConvert(CaseItr Begin, CaseItr End, 00069 ConstantInt *LowerBound, ConstantInt *UpperBound, 00070 Value *Val, BasicBlock *Predecessor, 00071 BasicBlock *OrigBlock, BasicBlock *Default); 00072 BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock, 00073 BasicBlock *Default); 00074 unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); 00075 }; 00076 00077 /// The comparison function for sorting the switch case values in the vector. 00078 /// WARNING: Case ranges should be disjoint! 00079 struct CaseCmp { 00080 bool operator () (const LowerSwitch::CaseRange& C1, 00081 const LowerSwitch::CaseRange& C2) { 00082 00083 const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); 00084 const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); 00085 return CI1->getValue().slt(CI2->getValue()); 00086 } 00087 }; 00088 } 00089 00090 char LowerSwitch::ID = 0; 00091 INITIALIZE_PASS(LowerSwitch, "lowerswitch", 00092 "Lower SwitchInst's to branches", false, false) 00093 00094 // Publicly exposed interface to pass... 00095 char &llvm::LowerSwitchID = LowerSwitch::ID; 00096 // createLowerSwitchPass - Interface to this file... 00097 FunctionPass *llvm::createLowerSwitchPass() { 00098 return new LowerSwitch(); 00099 } 00100 00101 bool LowerSwitch::runOnFunction(Function &F) { 00102 bool Changed = false; 00103 00104 for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { 00105 BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks 00106 00107 if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { 00108 Changed = true; 00109 processSwitchInst(SI); 00110 } 00111 } 00112 00113 return Changed; 00114 } 00115 00116 // operator<< - Used for debugging purposes. 00117 // 00118 static raw_ostream& operator<<(raw_ostream &O, 00119 const LowerSwitch::CaseVector &C) 00120 LLVM_ATTRIBUTE_USED; 00121 static raw_ostream& operator<<(raw_ostream &O, 00122 const LowerSwitch::CaseVector &C) { 00123 O << "["; 00124 00125 for (LowerSwitch::CaseVector::const_iterator B = C.begin(), 00126 E = C.end(); B != E; ) { 00127 O << *B->Low << " -" << *B->High; 00128 if (++B != E) O << ", "; 00129 } 00130 00131 return O << "]"; 00132 } 00133 00134 static void fixPhis(BasicBlock *Succ, 00135 BasicBlock *OrigBlock, 00136 BasicBlock *NewNode) { 00137 for (BasicBlock::iterator I = Succ->begin(), 00138 E = Succ->getFirstNonPHI(); 00139 I != E; ++I) { 00140 PHINode *PN = cast<PHINode>(I); 00141 00142 for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) { 00143 if (PN->getIncomingBlock(I) == OrigBlock) 00144 PN->setIncomingBlock(I, NewNode); 00145 } 00146 } 00147 } 00148 00149 // switchConvert - Convert the switch statement into a binary lookup of 00150 // the case values. The function recursively builds this tree. 00151 // LowerBound and UpperBound are used to keep track of the bounds for Val 00152 // that have already been checked by a block emitted by one of the previous 00153 // calls to switchConvert in the call stack. 00154 BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, 00155 ConstantInt *LowerBound, 00156 ConstantInt *UpperBound, Value *Val, 00157 BasicBlock *Predecessor, 00158 BasicBlock *OrigBlock, 00159 BasicBlock *Default) { 00160 unsigned Size = End - Begin; 00161 00162 if (Size == 1) { 00163 // Check if the Case Range is perfectly squeezed in between 00164 // already checked Upper and Lower bounds. If it is then we can avoid 00165 // emitting the code that checks if the value actually falls in the range 00166 // because the bounds already tell us so. 00167 if (Begin->Low == LowerBound && Begin->High == UpperBound) { 00168 fixPhis(Begin->BB, OrigBlock, Predecessor); 00169 return Begin->BB; 00170 } 00171 return newLeafBlock(*Begin, Val, OrigBlock, Default); 00172 } 00173 00174 unsigned Mid = Size / 2; 00175 std::vector<CaseRange> LHS(Begin, Begin + Mid); 00176 DEBUG(dbgs() << "LHS: " << LHS << "\n"); 00177 std::vector<CaseRange> RHS(Begin + Mid, End); 00178 DEBUG(dbgs() << "RHS: " << RHS << "\n"); 00179 00180 CaseRange &Pivot = *(Begin + Mid); 00181 DEBUG(dbgs() << "Pivot ==> " 00182 << cast<ConstantInt>(Pivot.Low)->getValue() 00183 << " -" << cast<ConstantInt>(Pivot.High)->getValue() << "\n"); 00184 00185 // NewLowerBound here should never be the integer minimal value. 00186 // This is because it is computed from a case range that is never 00187 // the smallest, so there is always a case range that has at least 00188 // a smaller value. 00189 ConstantInt *NewLowerBound = cast<ConstantInt>(Pivot.Low); 00190 ConstantInt *NewUpperBound; 00191 00192 // If we don't have a Default block then it means that we can never 00193 // have a value outside of a case range, so set the UpperBound to the highest 00194 // value in the LHS part of the case ranges. 00195 if (Default != nullptr) { 00196 // Because NewLowerBound is never the smallest representable integer 00197 // it is safe here to subtract one. 00198 NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), 00199 NewLowerBound->getValue() - 1); 00200 } else { 00201 CaseItr LastLHS = LHS.begin() + LHS.size() - 1; 00202 NewUpperBound = cast<ConstantInt>(LastLHS->High); 00203 } 00204 00205 DEBUG(dbgs() << "LHS Bounds ==> "; 00206 if (LowerBound) { 00207 dbgs() << cast<ConstantInt>(LowerBound)->getSExtValue(); 00208 } else { 00209 dbgs() << "NONE"; 00210 } 00211 dbgs() << " - " << NewUpperBound->getSExtValue() << "\n"; 00212 dbgs() << "RHS Bounds ==> "; 00213 dbgs() << NewLowerBound->getSExtValue() << " - "; 00214 if (UpperBound) { 00215 dbgs() << cast<ConstantInt>(UpperBound)->getSExtValue() << "\n"; 00216 } else { 00217 dbgs() << "NONE\n"; 00218 }); 00219 00220 // Create a new node that checks if the value is < pivot. Go to the 00221 // left branch if it is and right branch if not. 00222 Function* F = OrigBlock->getParent(); 00223 BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); 00224 00225 ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, 00226 Val, Pivot.Low, "Pivot"); 00227 00228 BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound, 00229 NewUpperBound, Val, NewNode, OrigBlock, 00230 Default); 00231 BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound, 00232 UpperBound, Val, NewNode, OrigBlock, 00233 Default); 00234 00235 Function::iterator FI = OrigBlock; 00236 F->getBasicBlockList().insert(++FI, NewNode); 00237 NewNode->getInstList().push_back(Comp); 00238 00239 BranchInst::Create(LBranch, RBranch, Comp, NewNode); 00240 return NewNode; 00241 } 00242 00243 // newLeafBlock - Create a new leaf block for the binary lookup tree. It 00244 // checks if the switch's value == the case's value. If not, then it 00245 // jumps to the default branch. At this point in the tree, the value 00246 // can't be another valid case value, so the jump to the "default" branch 00247 // is warranted. 00248 // 00249 BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, 00250 BasicBlock* OrigBlock, 00251 BasicBlock* Default) 00252 { 00253 Function* F = OrigBlock->getParent(); 00254 BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); 00255 Function::iterator FI = OrigBlock; 00256 F->getBasicBlockList().insert(++FI, NewLeaf); 00257 00258 // Emit comparison 00259 ICmpInst* Comp = nullptr; 00260 if (Leaf.Low == Leaf.High) { 00261 // Make the seteq instruction... 00262 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, 00263 Leaf.Low, "SwitchLeaf"); 00264 } else { 00265 // Make range comparison 00266 if (cast<ConstantInt>(Leaf.Low)->isMinValue(true /*isSigned*/)) { 00267 // Val >= Min && Val <= Hi --> Val <= Hi 00268 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, 00269 "SwitchLeaf"); 00270 } else if (cast<ConstantInt>(Leaf.Low)->isZero()) { 00271 // Val >= 0 && Val <= Hi --> Val <=u Hi 00272 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, 00273 "SwitchLeaf"); 00274 } else { 00275 // Emit V-Lo <=u Hi-Lo 00276 Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); 00277 Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, 00278 Val->getName()+".off", 00279 NewLeaf); 00280 Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); 00281 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, 00282 "SwitchLeaf"); 00283 } 00284 } 00285 00286 // Make the conditional branch... 00287 BasicBlock* Succ = Leaf.BB; 00288 BranchInst::Create(Succ, Default, Comp, NewLeaf); 00289 00290 // If there were any PHI nodes in this successor, rewrite one entry 00291 // from OrigBlock to come from NewLeaf. 00292 for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { 00293 PHINode* PN = cast<PHINode>(I); 00294 // Remove all but one incoming entries from the cluster 00295 uint64_t Range = cast<ConstantInt>(Leaf.High)->getSExtValue() - 00296 cast<ConstantInt>(Leaf.Low)->getSExtValue(); 00297 for (uint64_t j = 0; j < Range; ++j) { 00298 PN->removeIncomingValue(OrigBlock); 00299 } 00300 00301 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 00302 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 00303 PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); 00304 } 00305 00306 return NewLeaf; 00307 } 00308 00309 // Clusterify - Transform simple list of Cases into list of CaseRange's 00310 unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { 00311 unsigned numCmps = 0; 00312 00313 // Start with "simple" cases 00314 for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) 00315 Cases.push_back(CaseRange(i.getCaseValue(), i.getCaseValue(), 00316 i.getCaseSuccessor())); 00317 00318 std::sort(Cases.begin(), Cases.end(), CaseCmp()); 00319 00320 // Merge case into clusters 00321 if (Cases.size()>=2) 00322 for (CaseItr I = Cases.begin(), J = std::next(Cases.begin()); 00323 J != Cases.end();) { 00324 int64_t nextValue = cast<ConstantInt>(J->Low)->getSExtValue(); 00325 int64_t currentValue = cast<ConstantInt>(I->High)->getSExtValue(); 00326 BasicBlock* nextBB = J->BB; 00327 BasicBlock* currentBB = I->BB; 00328 00329 // If the two neighboring cases go to the same destination, merge them 00330 // into a single case. 00331 if ((nextValue-currentValue==1) && (currentBB == nextBB)) { 00332 I->High = J->High; 00333 J = Cases.erase(J); 00334 } else { 00335 I = J++; 00336 } 00337 } 00338 00339 for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { 00340 if (I->Low != I->High) 00341 // A range counts double, since it requires two compares. 00342 ++numCmps; 00343 } 00344 00345 return numCmps; 00346 } 00347 00348 // processSwitchInst - Replace the specified switch instruction with a sequence 00349 // of chained if-then insts in a balanced binary search. 00350 // 00351 void LowerSwitch::processSwitchInst(SwitchInst *SI) { 00352 BasicBlock *CurBlock = SI->getParent(); 00353 BasicBlock *OrigBlock = CurBlock; 00354 Function *F = CurBlock->getParent(); 00355 Value *Val = SI->getCondition(); // The value we are switching on... 00356 BasicBlock* Default = SI->getDefaultDest(); 00357 00358 // If there is only the default destination, don't bother with the code below. 00359 if (!SI->getNumCases()) { 00360 BranchInst::Create(SI->getDefaultDest(), CurBlock); 00361 CurBlock->getInstList().erase(SI); 00362 return; 00363 } 00364 00365 const bool DefaultIsUnreachable = 00366 Default->size() == 1 && isa<UnreachableInst>(Default->getTerminator()); 00367 // Create a new, empty default block so that the new hierarchy of 00368 // if-then statements go to this and the PHI nodes are happy. 00369 // if the default block is set as an unreachable we avoid creating one 00370 // because will never be a valid target. 00371 BasicBlock *NewDefault = nullptr; 00372 if (!DefaultIsUnreachable) { 00373 NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); 00374 F->getBasicBlockList().insert(Default, NewDefault); 00375 00376 BranchInst::Create(Default, NewDefault); 00377 } 00378 // If there is an entry in any PHI nodes for the default edge, make sure 00379 // to update them as well. 00380 for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) { 00381 PHINode *PN = cast<PHINode>(I); 00382 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 00383 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 00384 PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); 00385 } 00386 00387 // Prepare cases vector. 00388 CaseVector Cases; 00389 unsigned numCmps = Clusterify(Cases, SI); 00390 00391 DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() 00392 << ". Total compares: " << numCmps << "\n"); 00393 DEBUG(dbgs() << "Cases: " << Cases << "\n"); 00394 (void)numCmps; 00395 00396 ConstantInt *UpperBound = nullptr; 00397 ConstantInt *LowerBound = nullptr; 00398 00399 // Optimize the condition where Default is an unreachable block. In this case 00400 // we can make the bounds tightly fitted around the case value ranges, 00401 // because we know that the value passed to the switch should always be 00402 // exactly one of the case values. 00403 if (DefaultIsUnreachable) { 00404 CaseItr LastCase = Cases.begin() + Cases.size() - 1; 00405 UpperBound = cast<ConstantInt>(LastCase->High); 00406 LowerBound = cast<ConstantInt>(Cases.begin()->Low); 00407 } 00408 BasicBlock *SwitchBlock = 00409 switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, 00410 OrigBlock, OrigBlock, NewDefault); 00411 00412 // Branch to our shiny new if-then stuff... 00413 BranchInst::Create(SwitchBlock, OrigBlock); 00414 00415 // We are now done with the switch instruction, delete it. 00416 CurBlock->getInstList().erase(SI); 00417 00418 pred_iterator PI = pred_begin(Default), E = pred_end(Default); 00419 // If the Default block has no more predecessors just remove it 00420 if (PI == E) { 00421 DeleteDeadBlock(Default); 00422 } 00423 }