// LLVM API Documentation
00001 //===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===// 00002 // 00003 // The LLVM Compiler Infrastructure 00004 // 00005 // This file is distributed under the University of Illinois Open Source 00006 // License. See LICENSE.TXT for details. 00007 // 00008 //===----------------------------------------------------------------------===// 00009 // 00010 // This pass replaces occurrences of __nvvm_reflect("string") with an 00011 // integer based on -nvvm-reflect-list string=<int> option given to this pass. 00012 // If an undefined string value is seen in a call to __nvvm_reflect("string"), 00013 // a default value of 0 will be used. 00014 // 00015 //===----------------------------------------------------------------------===// 00016 00017 #include "NVPTX.h" 00018 #include "llvm/ADT/DenseMap.h" 00019 #include "llvm/ADT/SmallVector.h" 00020 #include "llvm/ADT/StringMap.h" 00021 #include "llvm/IR/Constants.h" 00022 #include "llvm/IR/DerivedTypes.h" 00023 #include "llvm/IR/Function.h" 00024 #include "llvm/IR/Instructions.h" 00025 #include "llvm/IR/Intrinsics.h" 00026 #include "llvm/IR/Module.h" 00027 #include "llvm/IR/Type.h" 00028 #include "llvm/Pass.h" 00029 #include "llvm/Support/CommandLine.h" 00030 #include "llvm/Support/Debug.h" 00031 #include "llvm/Support/raw_os_ostream.h" 00032 #include "llvm/Transforms/Scalar.h" 00033 #include <map> 00034 #include <sstream> 00035 #include <string> 00036 #include <vector> 00037 00038 #define NVVM_REFLECT_FUNCTION "__nvvm_reflect" 00039 00040 using namespace llvm; 00041 00042 #define DEBUG_TYPE "nvptx-reflect" 00043 00044 namespace llvm { void initializeNVVMReflectPass(PassRegistry &); } 00045 00046 namespace { 00047 class NVVMReflect : public ModulePass { 00048 private: 00049 StringMap<int> VarMap; 00050 typedef DenseMap<std::string, int>::iterator VarMapIter; 00051 00052 public: 00053 static char ID; 00054 NVVMReflect() : ModulePass(ID) { 00055 
initializeNVVMReflectPass(*PassRegistry::getPassRegistry()); 00056 VarMap.clear(); 00057 } 00058 00059 NVVMReflect(const StringMap<int> &Mapping) 00060 : ModulePass(ID) { 00061 initializeNVVMReflectPass(*PassRegistry::getPassRegistry()); 00062 for (StringMap<int>::const_iterator I = Mapping.begin(), E = Mapping.end(); 00063 I != E; ++I) { 00064 VarMap[(*I).getKey()] = (*I).getValue(); 00065 } 00066 } 00067 00068 void getAnalysisUsage(AnalysisUsage &AU) const override { 00069 AU.setPreservesAll(); 00070 } 00071 bool runOnModule(Module &) override; 00072 00073 private: 00074 bool handleFunction(Function *ReflectFunction); 00075 void setVarMap(); 00076 }; 00077 } 00078 00079 ModulePass *llvm::createNVVMReflectPass() { 00080 return new NVVMReflect(); 00081 } 00082 00083 ModulePass *llvm::createNVVMReflectPass(const StringMap<int>& Mapping) { 00084 return new NVVMReflect(Mapping); 00085 } 00086 00087 static cl::opt<bool> 00088 NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden, 00089 cl::desc("NVVM reflection, enabled by default")); 00090 00091 char NVVMReflect::ID = 0; 00092 INITIALIZE_PASS(NVVMReflect, "nvvm-reflect", 00093 "Replace occurrences of __nvvm_reflect() calls with 0/1", false, 00094 false) 00095 00096 static cl::list<std::string> 00097 ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden, 00098 cl::desc("A list of string=num assignments"), 00099 cl::ValueRequired); 00100 00101 /// The command line can look as follows : 00102 /// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2 00103 /// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the 00104 /// ReflectList vector. First, each of ReflectList[i] is 'split' 00105 /// using "," as the delimiter. Then each of this part is split 00106 /// using "=" as the delimiter. 
00107 void NVVMReflect::setVarMap() { 00108 for (unsigned i = 0, e = ReflectList.size(); i != e; ++i) { 00109 DEBUG(dbgs() << "Option : " << ReflectList[i] << "\n"); 00110 SmallVector<StringRef, 4> NameValList; 00111 StringRef(ReflectList[i]).split(NameValList, ","); 00112 for (unsigned j = 0, ej = NameValList.size(); j != ej; ++j) { 00113 SmallVector<StringRef, 2> NameValPair; 00114 NameValList[j].split(NameValPair, "="); 00115 assert(NameValPair.size() == 2 && "name=val expected"); 00116 std::stringstream ValStream(NameValPair[1]); 00117 int Val; 00118 ValStream >> Val; 00119 assert((!(ValStream.fail())) && "integer value expected"); 00120 VarMap[NameValPair[0]] = Val; 00121 } 00122 } 00123 } 00124 00125 bool NVVMReflect::handleFunction(Function *ReflectFunction) { 00126 // Validate _reflect function 00127 assert(ReflectFunction->isDeclaration() && 00128 "_reflect function should not have a body"); 00129 assert(ReflectFunction->getReturnType()->isIntegerTy() && 00130 "_reflect's return type should be integer"); 00131 00132 std::vector<Instruction *> ToRemove; 00133 00134 // Go through the uses of ReflectFunction in this Function. 00135 // Each of them should a CallInst with a ConstantArray argument. 00136 // First validate that. If the c-string corresponding to the 00137 // ConstantArray can be found successfully, see if it can be 00138 // found in VarMap. If so, replace the uses of CallInst with the 00139 // value found in VarMap. If not, replace the use with value 0. 00140 for (User *U : ReflectFunction->users()) { 00141 assert(isa<CallInst>(U) && "Only a call instruction can use _reflect"); 00142 CallInst *Reflect = cast<CallInst>(U); 00143 00144 assert((Reflect->getNumOperands() == 2) && 00145 "Only one operand expect for _reflect function"); 00146 // In cuda, we will have an extra constant-to-generic conversion of 00147 // the string. 
00148 const Value *Str = Reflect->getArgOperand(0); 00149 if (isa<CallInst>(Str)) { 00150 // CUDA path 00151 const CallInst *ConvCall = cast<CallInst>(Str); 00152 Str = ConvCall->getArgOperand(0); 00153 } 00154 assert(isa<ConstantExpr>(Str) && 00155 "Format of _reflect function not recognized"); 00156 const ConstantExpr *GEP = cast<ConstantExpr>(Str); 00157 00158 const Value *Sym = GEP->getOperand(0); 00159 assert(isa<Constant>(Sym) && "Format of _reflect function not recognized"); 00160 00161 const Constant *SymStr = cast<Constant>(Sym); 00162 00163 assert(isa<ConstantDataSequential>(SymStr->getOperand(0)) && 00164 "Format of _reflect function not recognized"); 00165 00166 assert(cast<ConstantDataSequential>(SymStr->getOperand(0))->isCString() && 00167 "Format of _reflect function not recognized"); 00168 00169 std::string ReflectArg = 00170 cast<ConstantDataSequential>(SymStr->getOperand(0))->getAsString(); 00171 00172 ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1); 00173 DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n"); 00174 00175 int ReflectVal = 0; // The default value is 0 00176 if (VarMap.find(ReflectArg) != VarMap.end()) { 00177 ReflectVal = VarMap[ReflectArg]; 00178 } 00179 Reflect->replaceAllUsesWith( 00180 ConstantInt::get(Reflect->getType(), ReflectVal)); 00181 ToRemove.push_back(Reflect); 00182 } 00183 if (ToRemove.size() == 0) 00184 return false; 00185 00186 for (unsigned i = 0, e = ToRemove.size(); i != e; ++i) 00187 ToRemove[i]->eraseFromParent(); 00188 return true; 00189 } 00190 00191 bool NVVMReflect::runOnModule(Module &M) { 00192 if (!NVVMReflectEnabled) 00193 return false; 00194 00195 setVarMap(); 00196 00197 00198 bool Res = false; 00199 std::string Name; 00200 Type *Tys[1]; 00201 Type *I8Ty = Type::getInt8Ty(M.getContext()); 00202 Function *ReflectFunction; 00203 00204 // Check for standard overloaded versions of llvm.nvvm.reflect 00205 00206 for (unsigned i = 0; i != 5; ++i) { 00207 Tys[0] = PointerType::get(I8Ty, i); 
00208 Name = Intrinsic::getName(Intrinsic::nvvm_reflect, Tys); 00209 ReflectFunction = M.getFunction(Name); 00210 if(ReflectFunction != 0) { 00211 Res |= handleFunction(ReflectFunction); 00212 } 00213 } 00214 00215 ReflectFunction = M.getFunction(NVVM_REFLECT_FUNCTION); 00216 // If reflect function is not used, then there will be 00217 // no entry in the module. 00218 if (ReflectFunction != 0) 00219 Res |= handleFunction(ReflectFunction); 00220 00221 return Res; 00222 }