17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/Sequence.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/IR/Constants.h"
21#include "llvm/IR/Metadata.h"
22#include "llvm/Transforms/Utils/UnrollLoop.h"
25#define DEBUG_TYPE "polly-opt-isl"
44 Target.
release(), TargetIdx, LoopType))
51 Target.
release(), TargetIdx, IsolateType))
59template <
typename CbTy>
64 bool ExcludeAny =
false;
65 bool IncludeAny =
false;
66 for (
auto OldIdx : seq<int>(0, NumBandDims)) {
67 if (IncludeCb(OldIdx))
83 for (
auto OldIdx : seq<int>(0, NumBandDims)) {
84 if (IncludeCb(OldIdx))
87 List = List.
drop(NewIdx, 1);
94 NewPartialSched = PartialSched;
111 for (
auto OldIdx : seq<int>(0, NumBandDims)) {
112 if (!IncludeCb(OldIdx))
115 applyBandMemberAttributes(std::move(NewBand), NewIdx, OldBand, OldIdx);
130template <
typename Derived,
typename... Args>
131struct ScheduleTreeRewriter
133 Derived &
getDerived() {
return *
static_cast<Derived *
>(
this); }
135 return *
static_cast<const Derived *
>(
this);
146 return rebuildBand(Band, NewChild, [](
int) {
return true; });
154 for (
int i = 1; i < NumChildren; i += 1)
156 getDerived().
visit(Sequence.
child(i), std::forward<Args>(args)...));
164 for (
int i = 1; i < NumChildren; i += 1)
168 .
visit(Set.
child(i), std::forward<Args>(args)...)
182 .visit(Mark.
first_child(), std::forward<Args>(args)...)
193 .visit(Extension.
child(0), args...)
210 llvm_unreachable(
"Not implemented");
216struct IdentityRewriter : ScheduleTreeRewriter<IdentityRewriter> {};
229struct ExtensionNodeRewriter final
230 : ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &,
232 using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter,
234 BaseTy &getBase() {
return *
this; }
235 const BaseTy &getBase()
const {
return *
this; }
250 for (
int i = 1; i < NumChildren; i += 1) {
254 NewNode = NewNode.
sequence(NewChildNode);
255 Extensions = Extensions.
unite(NewChildExtensions);
265 for (
int i = 1; i < NumChildren; i += 1) {
271 Extensions = Extensions.
unite(NewChildExtensions);
299 assert(ExtDims >= BandDims);
300 unsigned OuterDims = ExtDims - BandDims;
304 NewPartialSchedMap = NewPartialSchedMap.
unite(BandSched);
310 OuterExtensions = OuterExtensions.
unite(OuterSched);
324 for (
unsigned i = 0; i < BandDims; i += 1)
339 return visit(Filter.
first_child(), NewDomain, Extensions);
350 visit(Extension.
first_child(), NewDomain, ChildExtensions);
351 Extensions = ChildExtensions.
unite(ExtDomain);
360struct CollectASTBuildOptions final
363 BaseTy &getBase() {
return *
this; }
364 const BaseTy &getBase()
const {
return *
this; }
366 llvm::SmallVector<isl::union_set, 8> ASTBuildOptions;
369 ASTBuildOptions.push_back(
382 BaseTy &getBase() {
return *
this; }
383 const BaseTy &getBase()
const {
return *
this; }
386 llvm::ArrayRef<isl::union_set> ASTBuildOptions;
388 ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions)
389 : ASTBuildOptions(ASTBuildOptions) {}
394 assert(Pos == ASTBuildOptions.size() &&
395 "AST build options must match to band nodes");
403 return getBase().visitBand(Result);
422 Schedule.
get(), Callback,
nullptr);
430static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) {
431 return dyn_cast_or_null<MDNode>(
479 assert(isBandWithSingleLoop(BandOrMark));
491 MarkOrBand = moveToBandMark(MarkOrBand);
494 if (isMark(MarkOrBand)) {
502 assert(isBandWithSingleLoop(Band));
509 return removeMark(MarkOrBand, Attr);
515 "Don't add a two marks for a band");
533 isl::aff AffFactor{LUnispace, ValFactor};
534 isl::aff AffOffset{LUnispace, ValOffset};
563class BandCollapseRewriter final
564 :
public ScheduleTreeRewriter<BandCollapseRewriter> {
566 using BaseTy = ScheduleTreeRewriter<BandCollapseRewriter>;
567 BaseTy &getBase() {
return *
this; }
568 const BaseTy &getBase()
const {
return *
this; }
579 return getBase().visitBand(Band);
582 SmallVector<isl::schedule_node_band> Nest;
583 int NumTotalLoops = 0;
586 Nest.push_back(Band);
600 if (Nest.size() <= 1)
601 return getBase().visitBand(Band);
604 dbgs() <<
"Found loops to collapse between\n";
618 for (
auto j : seq<int>(0, NumLoops))
619 PartScheds = PartScheds.
add(BandScheds.
at(j));
634 for (
int i : seq<int>(0, NumLoops)) {
635 CollapsedBand = applyBandMemberAttributes(std::move(CollapsedBand),
640 assert(LoopIdx == NumTotalLoops &&
641 "Expect the same number of loops to add up again");
648 POLLY_DEBUG(dbgs() <<
"Collapse bands in schedule\n");
649 BandCollapseRewriter Rewriter;
650 return Rewriter.visit(Sched);
656static void collectPotentiallyFusableBands(
658 SmallVectorImpl<std::pair<isl::schedule_node, isl::schedule_node>>
670 collectPotentiallyFusableBands(
C, ScheduleBands, DirectChild);
671 if (!
C.has_next_sibling())
673 C =
C.next_sibling();
680 ScheduleBands.push_back({Node, DirectChild});
709 return ChildRemainingDeps;
714static isl::union_map remainigDepsFromSequence(ArrayRef<isl::union_set> Domains,
722 for (
auto P : enumerate(Domains)) {
725 PartialSchedules = PartialSchedules.
unite(DomSched.as_union_map());
728 return remainingDepsFromPartialSchedule(PartialSchedules, Deps);
774 if (!canFuseOutermost(LHS, RHS, Deps))
778 dbgs() <<
"Found loops for greedy fusion:\n";
791 IdentityRewriter Rewriter;
801 rebuildBand(LHS, LHSBody, [](
int i) {
return i > 0; });
803 rebuildBand(RHS, RHSBody, [](
int i) {
return i > 0; });
815 return NewCommonSchedule;
837class GreedyFusionRewriter final
838 :
public ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map> {
840 using BaseTy = ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map>;
841 BaseTy &getBase() {
return *
this; }
842 const BaseTy &getBase()
const {
return *
this; }
846 bool AnyChange =
false;
870 return getBase().visitBand(Band, RemDeps);
881 SmallVector<std::pair<isl::schedule_node, isl::schedule_node>> Bands;
882 for (
auto i : seq<int>(0, NumChildren)) {
884 collectPotentiallyFusableBands(Child, Bands, Child);
888 SmallDenseSet<isl_schedule_node *, 4> ChangedDirectChildren;
892 while (i + 1 < (
int)Bands.size()) {
894 tryGreedyFuse(Bands[i].first, Bands[i + 1].first, Deps);
902 if (!Bands[i].second.is_null())
903 ChangedDirectChildren.insert(Bands[i].second.get());
904 if (!Bands[i + 1].second.is_null())
905 ChangedDirectChildren.insert(Bands[i + 1].second.get());
910 Bands.erase(Bands.begin() + i + 1);
917 SmallVector<isl::union_set> SubDomains;
918 SubDomains.reserve(NumChildren);
919 for (
int i = 0; i < NumChildren; i += 1)
921 auto SubRemainingDeps = remainigDepsFromSequence(SubDomains, Deps);
925 SmallDenseSet<isl_schedule_node *, 4> AlreadyAdded;
928 for (
auto &P : Bands) {
934 !ChangedDirectChildren.count(DirectChild.
get())) {
935 if (AlreadyAdded.count(DirectChild.
get()))
937 AlreadyAdded.insert(DirectChild.
get());
938 MaybeFused = DirectChild;
941 "Need changed flag for be consistent with actual change");
946 isl::schedule InnerFused = visit(MaybeFused, SubRemainingDeps);
952 Result = Result.
sequence(InnerFused);
962 return isMark(Node) &&
967 MarkOrBand = moveToBandMark(MarkOrBand);
968 if (!isMark(MarkOrBand))
977 if (!containsExtensionNode(Sched))
983 CollectASTBuildOptions Collector;
984 Collector.visit(Sched);
987 ExtensionNodeRewriter Rewriter;
992 ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions);
993 NewSched = Applicator.visitSchedule(NewSched);
1002 BandToUnroll = removeMark(BandToUnroll);
1003 assert(isBandWithSingleLoop(BandToUnroll));
1008 "Can only unroll a single dimension");
1021 SmallVector<isl::point, 16> Elts;
1040 List = List.add(DomainFilter);
1052 assert(Factor > 0 &&
"Positive unroll factor required");
1057 BandToUnroll = removeMark(BandToUnroll, Attr);
1058 assert(isBandWithSingleLoop(BandToUnroll));
1077 StridedPartialSchedUAff = StridedPartialSchedUAff.
union_add(DivSchedAff);
1082 for (
auto i : seq<int>(0, Factor)) {
1093 List = List.
add(UnrolledDomain);
1102 MDNode *FollowupMD =
nullptr;
1105 findOptionalNodeOperand(Attr->
Metadata, LLVMLoopUnrollFollowupUnrolled);
1107 isl::id NewBandId = createGeneratedLoopAttr(
Ctx, FollowupMD);
1109 NewLoop = insertMark(NewLoop, NewBandId);
1120 auto ExtentPrefixes = addExtentConstraints(LoopPrefixes, VectorWidth);
1124 return LoopPrefixes.
subtract(BadPrefixes);
1128 unsigned OutDimsNum) {
1130 assert(OutDimsNum <= Dims &&
1131 "The isl::set IsolateDomain is used to describe the range of schedule "
1132 "dimensions values, which should be isolated. Consequently, the "
1133 "number of its dimensions should be greater than or equal to the "
1134 "number of the schedule dimensions.");
1137 Dims - OutDimsNum, OutDimsNum);
1148 DimOption = DimOption.set_tuple_id(Id);
1153 const char *Identifier,
1154 ArrayRef<int> TileSizes,
1155 int DefaultTileSize) {
1159 std::string IdentifierString(Identifier);
1161 unsigned tileSize = i < TileSizes.size() ? TileSizes[i] : DefaultTileSize;
1162 Sizes = Sizes.set_val(i,
isl::val(Node.
ctx(), tileSize));
1164 auto TileLoopMarkerStr = IdentifierString +
" - Tiles";
1167 Node = Node.
child(0);
1170 Node = Node.
child(0);
1171 auto PointLoopMarkerStr = IdentifierString +
" - Points";
1172 auto PointLoopMarker =
1175 return Node.
child(0);
1179 ArrayRef<int> TileSizes,
1180 int DefaultTileSize) {
1181 Node =
tileNode(Node,
"Register tiling", TileSizes, DefaultTileSize);
1190 SmallVectorImpl<isl::schedule_node> &ScheduleStmts) {
1191 if (isBand(Node) || isLeaf(Node)) {
1192 ScheduleStmts.push_back(Node);
1200 if (!
C.has_next_sibling())
1202 C =
C.next_sibling();
1209 BandToFission = removeMark(BandToFission);
1212 SmallVector<isl::schedule_node> FissionableStmts;
1214 size_t N = FissionableStmts.size();
1218 for (
size_t i = 0; i <
N; ++i) {
1234 GreedyFusionRewriter Rewriter;
1236 if (!Rewriter.AnyChange) {
1243 return collapseBands(Result);
static RegisterPass< ScopOnlyPrinterWrapperPass > N("dot-scops-only", "Polly - Print Scops of function (with no function bodies)")
isl::aff mod(isl::val mod) const
static isl::aff var_on_domain(isl::local_space ls, isl::dim type, unsigned int pos)
isl::basic_map fix_val(isl::dim type, unsigned int pos, isl::val v) const
static isl::basic_map from_aff(isl::aff aff)
isl::basic_set domain() const
isl::constraint set_constant_si(int v) const
static isl::constraint alloc_inequality(isl::local_space ls)
isl::constraint set_coefficient_si(isl::dim type, int pos, int v) const
static isl::id alloc(isl::ctx ctx, const std::string &name, void *user)
isl::map apply_range(isl::map map2) const
static isl::map from_domain(isl::set set)
static isl::map lex_gt(isl::space set_space)
isl::map apply_domain(isl::map map2) const
isl::map move_dims(isl::dim dst_type, unsigned int dst_pos, isl::dim src_type, unsigned int src_pos, unsigned int n) const
static isl::map lex_le(isl::space set_space)
isl::map project_out(isl::dim type, unsigned int first, unsigned int n) const
isl::union_pw_aff at(int pos) const
isl::union_pw_aff get_at(int pos) const
class size dim(isl::dim type) const
isl::union_pw_aff_list list() const
isl::space get_space() const
static isl::multi_union_pw_aff from_union_map(isl::union_map umap)
isl::multi_union_pw_aff add(isl::multi_union_pw_aff multi2) const
static isl::multi_val zero(isl::space space)
isl::val get_coordinate_val(isl::dim type, int pos) const
isl::pw_aff div(isl::pw_aff pa2) const
isl::pw_aff floor() const
isl::space get_space() const
isl::pw_aff mul(isl::pw_aff pwaff2) const
boolean member_get_coincident(int pos) const
schedule_node_band set_ast_build_options(isl::union_set options) const
schedule_node_band member_set_coincident(int pos, int coincident) const
isl::multi_union_pw_aff get_partial_schedule() const
schedule_node_band set_permutable(int permutable) const
boolean permutable() const
class size n_member() const
isl::union_set domain() const
isl::schedule_node insert_mark(isl::id mark) const
isl::schedule_node child(int pos) const
__isl_give isl_schedule_node * release()
isl::schedule_node insert_partial_schedule(isl::multi_union_pw_aff schedule) const
boolean has_children() const
isl::schedule_node insert_sequence(isl::union_set_list filters) const
__isl_give isl_schedule_node * copy() const &
isl::schedule get_schedule() const
isl::schedule_node graft_before(isl::schedule_node graft) const
static isl::schedule_node from_extension(isl::union_map extension)
isl::schedule_node parent() const
isl::schedule_node first_child() const
__isl_keep isl_schedule_node * get() const
isl::union_set get_domain() const
__isl_keep isl_schedule * get() const
isl::schedule insert_partial_schedule(isl::multi_union_pw_aff partial) const
isl::schedule intersect_domain(isl::union_set domain) const
isl::schedule_node get_root() const
static isl::schedule from_domain(isl::union_set domain)
isl::union_set get_domain() const
isl::schedule sequence(isl::schedule schedule2) const
__isl_give isl_schedule * release()
isl::set project_out(isl::dim type, unsigned int first, unsigned int n) const
isl::set subtract(isl::set set2) const
static isl::set universe(isl::space space)
isl::set set_tuple_id(isl::id id) const
class size tuple_dim() const
isl::set add_constraint(isl::constraint constraint) const
isl::space get_space() const
isl::set drop_constraints_involving_dims(isl::dim type, unsigned int first, unsigned int n) const
isl::space add_unnamed_tuple(unsigned int dim) const
isl::space params() const
isl::space domain() const
isl::space set_from_params() const
isl::space add_dims(isl::dim type, unsigned int n) const
isl::union_set range() const
isl::union_map reverse() const
isl::union_map unite(isl::union_map umap2) const
isl::union_set domain() const
isl::map_list get_map_list() const
isl::multi_union_pw_aff as_multi_union_pw_aff() const
isl::space get_space() const
isl::union_map apply_domain(isl::union_map umap2) const
isl::union_map intersect_range(isl::space space) const
static isl::union_map empty(isl::ctx ctx)
isl::union_map intersect(isl::union_map umap2) const
isl::union_map intersect_domain(isl::space space) const
static isl::union_map from(isl::multi_union_pw_aff mupa)
isl::union_pw_aff_list drop(unsigned int first, unsigned int n) const
isl::space get_space() const
isl::multi_union_pw_aff union_add(const isl::multi_union_pw_aff &mupa2) const
stat foreach_pw_aff(const std::function< stat(isl::pw_aff)> &fn) const
isl::union_map as_union_map() const
static isl::union_pw_aff empty(isl::space space)
isl::union_pw_aff intersect_domain(isl::space space) const
isl::union_set_list add(isl::union_set el) const
stat foreach_point(const std::function< stat(isl::point)> &fn) const
boolean lt(const isl::val &v2) const
enum isl_schedule_node_type isl_schedule_node_get_type(__isl_keep isl_schedule_node *node)
static __isl_keep isl_id * get_id(__isl_keep isl_space *space, enum isl_dim_type type, unsigned pos)
static bool is_equal(const T &a, const T &b)
boolean manage(isl_bool val)
This file contains the declaration of the PolyhedralInfo class, which will provide an interface to ex...
unsigned getNumScatterDims(const isl::union_map &Schedule)
Determine how many dimensions the scatter space of Schedule has.
bool isLoopAttr(const isl::id &Id)
Is Id representing a loop?
std::optional< llvm::Metadata * > findMetadataOperand(llvm::MDNode *LoopMD, llvm::StringRef Name)
Find a property value in a LoopID.
isl::schedule_node applyRegisterTiling(isl::schedule_node Node, llvm::ArrayRef< int > TileSizes, int DefaultTileSize)
Tile a schedule node and unroll point loops.
BandAttr * getLoopAttr(const isl::id &Id)
Return the BandAttr of a loop's isl::id.
isl::schedule applyGreedyFusion(isl::schedule Sched, const isl::union_map &Deps)
Apply greedy fusion.
llvm::iota_range< unsigned > rangeIslSize(unsigned Begin, isl::size End)
Check that End is valid and return an iterator from Begin to End.
BandAttr * getBandAttr(isl::schedule_node MarkOrBand)
Extract the BandAttr from a band's wrapping marker.
bool isBandMark(const isl::schedule_node &Node)
Is this node the marker for its parent band?
isl::id getIslLoopAttr(isl::ctx Ctx, BandAttr *Attr)
Get an isl::id representing a loop.
void dumpIslObj(const isl::schedule_node &Node, llvm::raw_ostream &OS)
Emit the equivalent of the isl_*_dump output into a raw_ostream.
isl::schedule applyMaxFission(isl::schedule_node BandToFission)
Loop-distribute the band BandToFission as much as possible.
isl::union_set getIsolateOptions(isl::set IsolateDomain, unsigned OutDimsNum)
Create an isl::union_set, which describes the isolate option based on IsolateDomain.
isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor)
Replace the AST band BandToUnroll by a partially unrolled equivalent.
isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier, llvm::ArrayRef< int > TileSizes, int DefaultTileSize)
Tile a schedule node.
isl::union_set getDimOptions(isl::ctx Ctx, const char *Option)
Create an isl::union_set, which describes the specified option for the dimension of the current node.
isl::schedule hoistExtensionNodes(isl::schedule Sched)
Hoist all domains from extension into the root domain node, such that there are no more extension nod...
isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth)
Build the desired set of partial tile prefixes.
isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll)
Replace the AST band BandToUnroll by a sequence of all its iterations.
__isl_give isl_schedule * isl_schedule_set(__isl_take isl_schedule *schedule1, __isl_take isl_schedule *schedule2)
isl_stat isl_schedule_foreach_schedule_node_top_down(__isl_keep isl_schedule *sched, isl_bool(*fn)(__isl_keep isl_schedule_node *node, void *user), void *user)
__isl_export isl_size isl_schedule_node_band_n_member(__isl_keep isl_schedule_node *node)
__isl_export __isl_give isl_multi_union_pw_aff * isl_schedule_node_band_get_partial_schedule(__isl_keep isl_schedule_node *node)
__isl_export isl_size isl_schedule_node_n_children(__isl_keep isl_schedule_node *node)
enum isl_ast_loop_type isl_schedule_node_band_member_get_ast_loop_type(__isl_keep isl_schedule_node *node, int pos)
__isl_export __isl_give isl_schedule_node * isl_schedule_node_band_tile(__isl_take isl_schedule_node *node, __isl_take isl_multi_val *sizes)
__isl_give isl_space * isl_schedule_node_band_get_space(__isl_keep isl_schedule_node *node)
__isl_export __isl_give isl_schedule_node * isl_schedule_node_band_set_permutable(__isl_take isl_schedule_node *node, int permutable)
__isl_export __isl_give isl_schedule_node * isl_schedule_node_band_member_set_ast_loop_type(__isl_take isl_schedule_node *node, int pos, enum isl_ast_loop_type type)
__isl_give isl_schedule_node * isl_schedule_node_band_member_set_isolate_ast_loop_type(__isl_take isl_schedule_node *node, int pos, enum isl_ast_loop_type type)
__isl_export isl_bool isl_schedule_node_band_get_permutable(__isl_keep isl_schedule_node *node)
__isl_export __isl_give isl_union_set * isl_schedule_node_band_get_ast_build_options(__isl_keep isl_schedule_node *node)
enum isl_ast_loop_type isl_schedule_node_band_member_get_isolate_ast_loop_type(__isl_keep isl_schedule_node *node, int pos)
__isl_give isl_schedule_node * isl_schedule_node_delete(__isl_take isl_schedule_node *node)
@ isl_schedule_node_filter
@ isl_schedule_node_domain
@ isl_schedule_node_extension
@ isl_schedule_node_sequence
Represent the attributes of a loop.
llvm::MDNode * Metadata
LoopID which stores the properties of the loop, such as transformations to apply and the metadata of ...
Recursively visit all nodes of a schedule tree.
RetTy visitNode(isl::schedule_node Node, Args... args)
By default, recursively visit the child nodes.
RetTy visit(isl::schedule Schedule, Args... args)
When visiting an entire schedule tree, start at its root node.
Recursively visit all nodes of a schedule tree while allowing changes.
RetTy visitExtension(isl::schedule_node_extension Extension, Args... args)
RetTy visitMark(isl::schedule_node_mark Mark, Args... args)
RetTy visitSequence(isl::schedule_node_sequence Sequence, Args... args)
RetTy visitSet(isl::schedule_node_set Set, Args... args)
RetTy visitBand(isl::schedule_node_band Band, Args... args)
RetTy visitDomain(isl::schedule_node_domain Domain, Args... args)
RetTy visitFilter(isl::schedule_node_filter Filter, Args... args)
RetTy visitLeaf(isl::schedule_node_leaf Leaf, Args... args)
static TupleKindPtr Domain("Domain")
static TupleKindPtr Leaf("Leaf")