LLVM API Documentation
00001 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===// 00002 // 00003 // The LLVM Compiler Infrastructure 00004 // 00005 // This file is distributed under the University of Illinois Open Source 00006 // License. See LICENSE.TXT for details. 00007 // 00008 //===----------------------------------------------------------------------===// 00009 // 00010 /// \file 00011 /// This pass removes performs the following type substitution on all 00012 /// non-compute shaders: 00013 /// 00014 /// v16i8 => i128 00015 /// - v16i8 is used for constant memory resource descriptors. This type is 00016 /// legal for some compute APIs, and we don't want to declare it as legal 00017 /// in the backend, because we want the legalizer to expand all v16i8 00018 /// operations. 00019 /// v1* => * 00020 /// - Having v1* types complicates the legalizer and we can easily replace 00021 /// - them with the element type. 00022 //===----------------------------------------------------------------------===// 00023 00024 #include "AMDGPU.h" 00025 #include "llvm/IR/IRBuilder.h" 00026 #include "llvm/IR/InstVisitor.h" 00027 00028 using namespace llvm; 00029 00030 namespace { 00031 00032 class SITypeRewriter : public FunctionPass, 00033 public InstVisitor<SITypeRewriter> { 00034 00035 static char ID; 00036 Module *Mod; 00037 Type *v16i8; 00038 Type *v4i32; 00039 00040 public: 00041 SITypeRewriter() : FunctionPass(ID) { } 00042 bool doInitialization(Module &M) override; 00043 bool runOnFunction(Function &F) override; 00044 const char *getPassName() const override { 00045 return "SI Type Rewriter"; 00046 } 00047 void visitLoadInst(LoadInst &I); 00048 void visitCallInst(CallInst &I); 00049 void visitBitCast(BitCastInst &I); 00050 }; 00051 00052 } // End anonymous namespace 00053 00054 char SITypeRewriter::ID = 0; 00055 00056 bool SITypeRewriter::doInitialization(Module &M) { 00057 Mod = &M; 00058 v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16); 00059 v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4); 00060 return false; 00061 } 00062 00063 bool SITypeRewriter::runOnFunction(Function &F) { 00064 AttributeSet Set = F.getAttributes(); 00065 Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType"); 00066 00067 unsigned ShaderType = ShaderType::COMPUTE; 00068 if (A.isStringAttribute()) { 00069 StringRef Str = A.getValueAsString(); 00070 Str.getAsInteger(0, ShaderType); 00071 } 00072 if (ShaderType == ShaderType::COMPUTE) 00073 return false; 00074 00075 visit(F); 00076 visit(F); 00077 00078 return false; 00079 } 00080 00081 void SITypeRewriter::visitLoadInst(LoadInst &I) { 00082 Value *Ptr = I.getPointerOperand(); 00083 Type *PtrTy = Ptr->getType(); 00084 Type *ElemTy = PtrTy->getPointerElementType(); 00085 IRBuilder<> Builder(&I); 00086 if (ElemTy == v16i8) { 00087 Value *BitCast = Builder.CreateBitCast(Ptr, 00088 PointerType::get(v4i32,PtrTy->getPointerAddressSpace())); 00089 LoadInst *Load = Builder.CreateLoad(BitCast); 00090 SmallVector <std::pair<unsigned, MDNode*>, 8> MD; 00091 I.getAllMetadataOtherThanDebugLoc(MD); 00092 for (unsigned i = 0, e = MD.size(); i != e; ++i) { 00093 Load->setMetadata(MD[i].first, MD[i].second); 00094 } 00095 Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType()); 00096 I.replaceAllUsesWith(BitCastLoad); 00097 I.eraseFromParent(); 00098 } 00099 } 00100 00101 void SITypeRewriter::visitCallInst(CallInst &I) { 00102 IRBuilder<> Builder(&I); 00103 00104 SmallVector <Value*, 8> Args; 00105 SmallVector <Type*, 8> Types; 00106 bool NeedToReplace = false; 00107 Function *F = I.getCalledFunction(); 00108 std::string Name = F->getName().str(); 00109 for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) { 00110 Value *Arg = I.getArgOperand(i); 00111 if (Arg->getType() == v16i8) { 00112 Args.push_back(Builder.CreateBitCast(Arg, v4i32)); 00113 Types.push_back(v4i32); 00114 NeedToReplace = true; 00115 Name = Name + ".v4i32"; 00116 } else if (Arg->getType()->isVectorTy() && 00117 Arg->getType()->getVectorNumElements() == 1 && 00118 Arg->getType()->getVectorElementType() == 00119 Type::getInt32Ty(I.getContext())){ 00120 Type *ElementTy = Arg->getType()->getVectorElementType(); 00121 std::string TypeName = "i32"; 00122 InsertElementInst *Def = cast<InsertElementInst>(Arg); 00123 Args.push_back(Def->getOperand(1)); 00124 Types.push_back(ElementTy); 00125 std::string VecTypeName = "v1" + TypeName; 00126 Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName); 00127 NeedToReplace = true; 00128 } else { 00129 Args.push_back(Arg); 00130 Types.push_back(Arg->getType()); 00131 } 00132 } 00133 00134 if (!NeedToReplace) { 00135 return; 00136 } 00137 Function *NewF = Mod->getFunction(Name); 00138 if (!NewF) { 00139 NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod); 00140 NewF->setAttributes(F->getAttributes()); 00141 } 00142 I.replaceAllUsesWith(Builder.CreateCall(NewF, Args)); 00143 I.eraseFromParent(); 00144 } 00145 00146 void SITypeRewriter::visitBitCast(BitCastInst &I) { 00147 IRBuilder<> Builder(&I); 00148 if (I.getDestTy() != v4i32) { 00149 return; 00150 } 00151 00152 if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) { 00153 if (Op->getSrcTy() == v4i32) { 00154 I.replaceAllUsesWith(Op->getOperand(0)); 00155 I.eraseFromParent(); 00156 } 00157 } 00158 } 00159 00160 FunctionPass *llvm::createSITypeRewriter() { 00161 return new SITypeRewriter(); 00162 }