LLVM API Documentation

AArch64PBQPRegAlloc.cpp
Go to the documentation of this file.
00001 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
00002 //
00003 //                     The LLVM Compiler Infrastructure
00004 //
00005 // This file is distributed under the University of Illinois Open Source
00006 // License. See LICENSE.TXT for details.
00007 //
00008 //===----------------------------------------------------------------------===//
00009 // This file contains the AArch64 / Cortex-A57 specific register allocation
00010 // constraints for use by the PBQP register allocator.
00011 //
00012 // It is essentially a transcription of what is contained in
00013 // AArch64A57FPLoadBalancing, which tries to use a balanced
00014 // mix of odd and even D-registers when performing a critical sequence of
00015 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
00016 //===----------------------------------------------------------------------===//
00017 
00018 #define DEBUG_TYPE "aarch64-pbqp"
00019 
00020 #include "AArch64.h"
00021 #include "AArch64RegisterInfo.h"
00022 
00023 #include "llvm/ADT/SetVector.h"
00024 #include "llvm/CodeGen/LiveIntervalAnalysis.h"
00025 #include "llvm/CodeGen/MachineBasicBlock.h"
00026 #include "llvm/CodeGen/MachineFunction.h"
00027 #include "llvm/CodeGen/MachineRegisterInfo.h"
00028 #include "llvm/CodeGen/RegAllocPBQP.h"
00029 #include "llvm/Support/Debug.h"
00030 #include "llvm/Support/ErrorHandling.h"
00031 #include "llvm/Support/raw_ostream.h"
00032 
00033 #define PBQP_BUILDER PBQPBuilderWithCoalescing
00034 
00035 using namespace llvm;
00036 
00037 namespace {
00038 
00039 #ifndef NDEBUG
00040 bool isFPReg(unsigned reg) {
00041   return AArch64::FPR32RegClass.contains(reg) ||
00042          AArch64::FPR64RegClass.contains(reg) ||
00043          AArch64::FPR128RegClass.contains(reg);
00044 }
00045 #endif
00046 
00047 bool isOdd(unsigned reg) {
00048   switch (reg) {
00049   default:
00050     llvm_unreachable("Register is not from the expected class !");
00051   case AArch64::S1:
00052   case AArch64::S3:
00053   case AArch64::S5:
00054   case AArch64::S7:
00055   case AArch64::S9:
00056   case AArch64::S11:
00057   case AArch64::S13:
00058   case AArch64::S15:
00059   case AArch64::S17:
00060   case AArch64::S19:
00061   case AArch64::S21:
00062   case AArch64::S23:
00063   case AArch64::S25:
00064   case AArch64::S27:
00065   case AArch64::S29:
00066   case AArch64::S31:
00067   case AArch64::D1:
00068   case AArch64::D3:
00069   case AArch64::D5:
00070   case AArch64::D7:
00071   case AArch64::D9:
00072   case AArch64::D11:
00073   case AArch64::D13:
00074   case AArch64::D15:
00075   case AArch64::D17:
00076   case AArch64::D19:
00077   case AArch64::D21:
00078   case AArch64::D23:
00079   case AArch64::D25:
00080   case AArch64::D27:
00081   case AArch64::D29:
00082   case AArch64::D31:
00083   case AArch64::Q1:
00084   case AArch64::Q3:
00085   case AArch64::Q5:
00086   case AArch64::Q7:
00087   case AArch64::Q9:
00088   case AArch64::Q11:
00089   case AArch64::Q13:
00090   case AArch64::Q15:
00091   case AArch64::Q17:
00092   case AArch64::Q19:
00093   case AArch64::Q21:
00094   case AArch64::Q23:
00095   case AArch64::Q25:
00096   case AArch64::Q27:
00097   case AArch64::Q29:
00098   case AArch64::Q31:
00099     return true;
00100   case AArch64::S0:
00101   case AArch64::S2:
00102   case AArch64::S4:
00103   case AArch64::S6:
00104   case AArch64::S8:
00105   case AArch64::S10:
00106   case AArch64::S12:
00107   case AArch64::S14:
00108   case AArch64::S16:
00109   case AArch64::S18:
00110   case AArch64::S20:
00111   case AArch64::S22:
00112   case AArch64::S24:
00113   case AArch64::S26:
00114   case AArch64::S28:
00115   case AArch64::S30:
00116   case AArch64::D0:
00117   case AArch64::D2:
00118   case AArch64::D4:
00119   case AArch64::D6:
00120   case AArch64::D8:
00121   case AArch64::D10:
00122   case AArch64::D12:
00123   case AArch64::D14:
00124   case AArch64::D16:
00125   case AArch64::D18:
00126   case AArch64::D20:
00127   case AArch64::D22:
00128   case AArch64::D24:
00129   case AArch64::D26:
00130   case AArch64::D28:
00131   case AArch64::D30:
00132   case AArch64::Q0:
00133   case AArch64::Q2:
00134   case AArch64::Q4:
00135   case AArch64::Q6:
00136   case AArch64::Q8:
00137   case AArch64::Q10:
00138   case AArch64::Q12:
00139   case AArch64::Q14:
00140   case AArch64::Q16:
00141   case AArch64::Q18:
00142   case AArch64::Q20:
00143   case AArch64::Q22:
00144   case AArch64::Q24:
00145   case AArch64::Q26:
00146   case AArch64::Q28:
00147   case AArch64::Q30:
00148     return false;
00149 
00150   }
00151 }
00152 
00153 bool haveSameParity(unsigned reg1, unsigned reg2) {
00154   assert(isFPReg(reg1) && "Expecting an FP register for reg1");
00155   assert(isFPReg(reg2) && "Expecting an FP register for reg2");
00156 
00157   return isOdd(reg1) == isOdd(reg2);
00158 }
00159 
00160 class A57PBQPBuilder : public PBQP_BUILDER {
00161 public:
00162   A57PBQPBuilder() : PBQP_BUILDER(), TRI(nullptr), LIs(nullptr), Chains() {}
00163 
00164   // Build a PBQP instance to represent the register allocation problem for
00165   // the given MachineFunction.
00166   std::unique_ptr<PBQPRAProblem>
00167   build(MachineFunction *MF, const LiveIntervals *LI,
00168         const MachineBlockFrequencyInfo *blockInfo,
00169         const RegSet &VRegs) override;
00170 
00171 private:
00172   const AArch64RegisterInfo *TRI;
00173   const LiveIntervals *LIs;
00174   SmallSetVector<unsigned, 32> Chains;
00175 
00176   // Return true if reg is a physical register
00177   bool isPhysicalReg(unsigned reg) const {
00178     return TRI->isPhysicalRegister(reg);
00179   }
00180 
00181   // Add the accumulator chaining constraint, inside the chain, i.e. so that
00182   // parity(Rd) == parity(Ra).
00183   // \return true if a constraint was added
00184   bool addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
00185 
00186   // Add constraints between existing chains
00187   void addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
00188 };
00189 } // Anonymous namespace
00190 
00191 bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd,
00192                                              unsigned Ra) {
00193   if (Rd == Ra)
00194     return false;
00195 
00196   if (isPhysicalReg(Rd) || isPhysicalReg(Ra)) {
00197     DEBUG(dbgs() << "Rd is a physical reg:" << isPhysicalReg(Rd) << '\n');
00198     DEBUG(dbgs() << "Ra is a physical reg:" << isPhysicalReg(Ra) << '\n');
00199     return false;
00200   }
00201 
00202   const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
00203   const PBQPRAProblem::AllowedSet *vRaAllowed = &p->getAllowedSet(Ra);
00204 
00205   PBQPRAGraph &g = p->getGraph();
00206   PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
00207   PBQPRAGraph::NodeId node2 = p->getNodeForVReg(Ra);
00208   PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
00209 
00210   // The edge does not exist. Create one with the appropriate interference
00211   // costs.
00212   if (edge == g.invalidEdgeId()) {
00213     const LiveInterval &ld = LIs->getInterval(Rd);
00214     const LiveInterval &la = LIs->getInterval(Ra);
00215     bool livesOverlap = ld.overlaps(la);
00216 
00217     PBQP::Matrix costs(vRdAllowed->size() + 1, vRaAllowed->size() + 1, 0);
00218     for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
00219       unsigned pRd = (*vRdAllowed)[i];
00220       for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
00221         unsigned pRa = (*vRaAllowed)[j];
00222         if (livesOverlap && TRI->regsOverlap(pRd, pRa))
00223           costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
00224         else
00225           costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
00226       }
00227     }
00228     g.addEdge(node1, node2, std::move(costs));
00229     return true;
00230   }
00231 
00232   if (g.getEdgeNode1Id(edge) == node2) {
00233     std::swap(node1, node2);
00234     std::swap(vRdAllowed, vRaAllowed);
00235   }
00236 
00237   // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
00238   PBQP::Matrix costs(g.getEdgeCosts(edge));
00239   for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
00240     unsigned pRd = (*vRdAllowed)[i];
00241 
00242     // Get the maximum cost (excluding unallocatable reg) for same parity
00243     // registers
00244     PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
00245     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
00246       unsigned pRa = (*vRaAllowed)[j];
00247       if (haveSameParity(pRd, pRa))
00248         if (costs[i + 1][j + 1] !=
00249                 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
00250             costs[i + 1][j + 1] > sameParityMax)
00251           sameParityMax = costs[i + 1][j + 1];
00252     }
00253 
00254     // Ensure all registers with a different parity have a higher cost
00255     // than sameParityMax
00256     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
00257       unsigned pRa = (*vRaAllowed)[j];
00258       if (!haveSameParity(pRd, pRa))
00259         if (sameParityMax > costs[i + 1][j + 1])
00260           costs[i + 1][j + 1] = sameParityMax + 1.0;
00261     }
00262   }
00263   g.setEdgeCosts(edge, costs);
00264 
00265   return true;
00266 }
00267 
00268 void
00269 A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd,
00270                                         unsigned Ra) {
00271   // Do some Chain management
00272   if (Chains.count(Ra)) {
00273     if (Rd != Ra) {
00274       DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, TRI) << " to "
00275                    << PrintReg(Rd, TRI) << '\n';);
00276       Chains.remove(Ra);
00277       Chains.insert(Rd);
00278     }
00279   } else {
00280     DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, TRI)
00281                  << '\n';);
00282     Chains.insert(Rd);
00283   }
00284 
00285   const LiveInterval &ld = LIs->getInterval(Rd);
00286   for (auto r : Chains) {
00287     // Skip self
00288     if (r == Rd)
00289       continue;
00290 
00291     const LiveInterval &lr = LIs->getInterval(r);
00292     if (ld.overlaps(lr)) {
00293       const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
00294       const PBQPRAProblem::AllowedSet *vRrAllowed = &p->getAllowedSet(r);
00295 
00296       PBQPRAGraph &g = p->getGraph();
00297       PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
00298       PBQPRAGraph::NodeId node2 = p->getNodeForVReg(r);
00299       PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
00300       assert(edge != g.invalidEdgeId() &&
00301              "PBQP error ! The edge should exist !");
00302 
00303       DEBUG(dbgs() << "Refining constraint !\n";);
00304 
00305       if (g.getEdgeNode1Id(edge) == node2) {
00306         std::swap(node1, node2);
00307         std::swap(vRdAllowed, vRrAllowed);
00308       }
00309 
00310       // Enforce that cost is higher with all other Chains of the same parity
00311       PBQP::Matrix costs(g.getEdgeCosts(edge));
00312       for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
00313         unsigned pRd = (*vRdAllowed)[i];
00314 
00315         // Get the maximum cost (excluding unallocatable reg) for all other
00316         // parity registers
00317         PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
00318         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
00319           unsigned pRa = (*vRrAllowed)[j];
00320           if (!haveSameParity(pRd, pRa))
00321             if (costs[i + 1][j + 1] !=
00322                     std::numeric_limits<PBQP::PBQPNum>::infinity() &&
00323                 costs[i + 1][j + 1] > sameParityMax)
00324               sameParityMax = costs[i + 1][j + 1];
00325         }
00326 
00327         // Ensure all registers with same parity have a higher cost
00328         // than sameParityMax
00329         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
00330           unsigned pRa = (*vRrAllowed)[j];
00331           if (haveSameParity(pRd, pRa))
00332             if (sameParityMax > costs[i + 1][j + 1])
00333               costs[i + 1][j + 1] = sameParityMax + 1.0;
00334         }
00335       }
00336       g.setEdgeCosts(edge, costs);
00337     }
00338   }
00339 }
00340 
00341 std::unique_ptr<PBQPRAProblem>
00342 A57PBQPBuilder::build(MachineFunction *MF, const LiveIntervals *LI,
00343                       const MachineBlockFrequencyInfo *blockInfo,
00344                       const RegSet &VRegs) {
00345   std::unique_ptr<PBQPRAProblem> p =
00346       PBQP_BUILDER::build(MF, LI, blockInfo, VRegs);
00347 
00348   TRI = static_cast<const AArch64RegisterInfo *>(
00349       MF->getTarget().getSubtargetImpl()->getRegisterInfo());
00350   LIs = LI;
00351 
00352   DEBUG(MF->dump(););
00353 
00354   for (MachineFunction::const_iterator mbbItr = MF->begin(), mbbEnd = MF->end();
00355        mbbItr != mbbEnd; ++mbbItr) {
00356     const MachineBasicBlock *MBB = &*mbbItr;
00357     Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
00358 
00359     for (MachineBasicBlock::const_iterator miItr = MBB->begin(),
00360                                            miEnd = MBB->end();
00361          miItr != miEnd; ++miItr) {
00362       const MachineInstr *MI = &*miItr;
00363       switch (MI->getOpcode()) {
00364       case AArch64::FMSUBSrrr:
00365       case AArch64::FMADDSrrr:
00366       case AArch64::FNMSUBSrrr:
00367       case AArch64::FNMADDSrrr:
00368       case AArch64::FMSUBDrrr:
00369       case AArch64::FMADDDrrr:
00370       case AArch64::FNMSUBDrrr:
00371       case AArch64::FNMADDDrrr: {
00372         unsigned Rd = MI->getOperand(0).getReg();
00373         unsigned Ra = MI->getOperand(3).getReg();
00374 
00375         if (addIntraChainConstraint(p.get(), Rd, Ra))
00376           addInterChainConstraint(p.get(), Rd, Ra);
00377         break;
00378       }
00379 
00380       case AArch64::FMLAv2f32:
00381       case AArch64::FMLSv2f32: {
00382         unsigned Rd = MI->getOperand(0).getReg();
00383         addInterChainConstraint(p.get(), Rd, Rd);
00384         break;
00385       }
00386 
00387       default:
00388         // Forget Chains which have been killed
00389         for (auto r : Chains) {
00390           SmallVector<unsigned, 8> toDel;
00391           if (MI->killsRegister(r)) {
00392             DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
00393                   MI->print(dbgs()););
00394             toDel.push_back(r);
00395           }
00396 
00397           while (!toDel.empty()) {
00398             Chains.remove(toDel.back());
00399             toDel.pop_back();
00400           }
00401         }
00402       }
00403     }
00404   }
00405 
00406   return p;
00407 }
00408 
00409 // Factory function used by AArch64TargetMachine to add the pass to the
00410 // passmanager.
00411 FunctionPass *llvm::createAArch64A57PBQPRegAlloc() {
00412   std::unique_ptr<PBQP_BUILDER> builder = llvm::make_unique<A57PBQPBuilder>();
00413   return createPBQPRegisterAllocator(std::move(builder), nullptr);
00414 }