LLVM API Documentation

SITypeRewriter.cpp
Go to the documentation of this file.
00001 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
00002 //
00003 //                     The LLVM Compiler Infrastructure
00004 //
00005 // This file is distributed under the University of Illinois Open Source
00006 // License. See LICENSE.TXT for details.
00007 //
00008 //===----------------------------------------------------------------------===//
00009 //
00010 /// \file
00011 /// This pass performs the following type substitution on all
00012 /// non-compute shaders:
00013 ///
00014 /// v16i8 => v4i32
00015 ///   - v16i8 is used for constant memory resource descriptors.  This type is
00016 ///      legal for some compute APIs, and we don't want to declare it as legal
00017 ///      in the backend, because we want the legalizer to expand all v16i8
00018 ///      operations.
00019 /// v1* => *
00020 ///   - Having v1* types complicates the legalizer and we can easily replace
00021 ///     them with the element type.
00022 //===----------------------------------------------------------------------===//
00023 
00024 #include "AMDGPU.h"
00025 #include "llvm/IR/IRBuilder.h"
00026 #include "llvm/IR/InstVisitor.h"
00027 
00028 using namespace llvm;
00029 
00030 namespace {
00031 
00032 class SITypeRewriter : public FunctionPass,
00033                        public InstVisitor<SITypeRewriter> {
00034 
00035   static char ID;
00036   Module *Mod;
00037   Type *v16i8;
00038   Type *v4i32;
00039 
00040 public:
00041   SITypeRewriter() : FunctionPass(ID) { }
00042   bool doInitialization(Module &M) override;
00043   bool runOnFunction(Function &F) override;
00044   const char *getPassName() const override {
00045     return "SI Type Rewriter";
00046   }
00047   void visitLoadInst(LoadInst &I);
00048   void visitCallInst(CallInst &I);
00049   void visitBitCast(BitCastInst &I);
00050 };
00051 
00052 } // End anonymous namespace
00053 
00054 char SITypeRewriter::ID = 0; // Pass identifier; the SITypeRewriter constructor hands it to FunctionPass.
00055 
00056 bool SITypeRewriter::doInitialization(Module &M) {
00057   Mod = &M;
00058   v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
00059   v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4);
00060   return false;
00061 }
00062 
00063 bool SITypeRewriter::runOnFunction(Function &F) {
00064   AttributeSet Set = F.getAttributes();
00065   Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType");
00066 
00067   unsigned ShaderType = ShaderType::COMPUTE;
00068   if (A.isStringAttribute()) {
00069     StringRef Str = A.getValueAsString();
00070     Str.getAsInteger(0, ShaderType);
00071   }
00072   if (ShaderType == ShaderType::COMPUTE)
00073     return false;
00074 
00075   visit(F);
00076   visit(F);
00077 
00078   return false;
00079 }
00080 
00081 void SITypeRewriter::visitLoadInst(LoadInst &I) {
00082   Value *Ptr = I.getPointerOperand();
00083   Type *PtrTy = Ptr->getType();
00084   Type *ElemTy = PtrTy->getPointerElementType();
00085   IRBuilder<> Builder(&I);
00086   if (ElemTy == v16i8)  {
00087     Value *BitCast = Builder.CreateBitCast(Ptr,
00088         PointerType::get(v4i32,PtrTy->getPointerAddressSpace()));
00089     LoadInst *Load = Builder.CreateLoad(BitCast);
00090     SmallVector <std::pair<unsigned, MDNode*>, 8> MD;
00091     I.getAllMetadataOtherThanDebugLoc(MD);
00092     for (unsigned i = 0, e = MD.size(); i != e; ++i) {
00093       Load->setMetadata(MD[i].first, MD[i].second);
00094     }
00095     Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
00096     I.replaceAllUsesWith(BitCastLoad);
00097     I.eraseFromParent();
00098   }
00099 }
00100 
00101 void SITypeRewriter::visitCallInst(CallInst &I) {
00102   IRBuilder<> Builder(&I);
00103 
00104   SmallVector <Value*, 8> Args;
00105   SmallVector <Type*, 8> Types;
00106   bool NeedToReplace = false;
00107   Function *F = I.getCalledFunction();
00108   std::string Name = F->getName().str();
00109   for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
00110     Value *Arg = I.getArgOperand(i);
00111     if (Arg->getType() == v16i8) {
00112       Args.push_back(Builder.CreateBitCast(Arg, v4i32));
00113       Types.push_back(v4i32);
00114       NeedToReplace = true;
00115       Name = Name + ".v4i32";
00116     } else if (Arg->getType()->isVectorTy() &&
00117                Arg->getType()->getVectorNumElements() == 1 &&
00118                Arg->getType()->getVectorElementType() ==
00119                                               Type::getInt32Ty(I.getContext())){
00120       Type *ElementTy = Arg->getType()->getVectorElementType();
00121       std::string TypeName = "i32";
00122       InsertElementInst *Def = cast<InsertElementInst>(Arg);
00123       Args.push_back(Def->getOperand(1));
00124       Types.push_back(ElementTy);
00125       std::string VecTypeName = "v1" + TypeName;
00126       Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
00127       NeedToReplace = true;
00128     } else {
00129       Args.push_back(Arg);
00130       Types.push_back(Arg->getType());
00131     }
00132   }
00133 
00134   if (!NeedToReplace) {
00135     return;
00136   }
00137   Function *NewF = Mod->getFunction(Name);
00138   if (!NewF) {
00139     NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
00140     NewF->setAttributes(F->getAttributes());
00141   }
00142   I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
00143   I.eraseFromParent();
00144 }
00145 
00146 void SITypeRewriter::visitBitCast(BitCastInst &I) {
00147   IRBuilder<> Builder(&I);
00148   if (I.getDestTy() != v4i32) {
00149     return;
00150   }
00151 
00152   if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
00153     if (Op->getSrcTy() == v4i32) {
00154       I.replaceAllUsesWith(Op->getOperand(0));
00155       I.eraseFromParent();
00156     }
00157   }
00158 }
00159 
00160 FunctionPass *llvm::createSITypeRewriter() {
00161   return new SITypeRewriter();
00162 }