Polly 20.0.0git
ScheduleTreeTransform.h
Go to the documentation of this file.
1//===- polly/ScheduleTreeTransform.h ----------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Make changes to isl's schedule tree data structure.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef POLLY_SCHEDULETREETRANSFORM_H
14#define POLLY_SCHEDULETREETRANSFORM_H
15
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/Support/ErrorHandling.h"
20#include <cassert>
21
22namespace polly {
23struct BandAttr;
24
25/// This class defines a simple visitor class that may be used for
26/// various schedule tree analysis purposes.
27template <typename Derived, typename RetTy = void, typename... Args>
29 Derived &getDerived() { return *static_cast<Derived *>(this); }
30 const Derived &getDerived() const {
31 return *static_cast<const Derived *>(this);
32 }
33
34 RetTy visit(isl::schedule_node Node, Args... args) {
35 assert(!Node.is_null());
36 switch (isl_schedule_node_get_type(Node.get())) {
39 return getDerived().visitDomain(Node.as<isl::schedule_node_domain>(),
40 std::forward<Args>(args)...);
43 return getDerived().visitBand(Node.as<isl::schedule_node_band>(),
44 std::forward<Args>(args)...);
47 return getDerived().visitSequence(Node.as<isl::schedule_node_sequence>(),
48 std::forward<Args>(args)...);
51 return getDerived().visitSet(Node.as<isl::schedule_node_set>(),
52 std::forward<Args>(args)...);
55 return getDerived().visitLeaf(Node.as<isl::schedule_node_leaf>(),
56 std::forward<Args>(args)...);
59 return getDerived().visitMark(Node.as<isl::schedule_node_mark>(),
60 std::forward<Args>(args)...);
63 return getDerived().visitExtension(
64 Node.as<isl::schedule_node_extension>(), std::forward<Args>(args)...);
67 return getDerived().visitFilter(Node.as<isl::schedule_node_filter>(),
68 std::forward<Args>(args)...);
69 default:
70 llvm_unreachable("unimplemented schedule node type");
71 }
72 }
73
75 return getDerived().visitSingleChild(std::move(Domain),
76 std::forward<Args>(args)...);
77 }
78
79 RetTy visitBand(isl::schedule_node_band Band, Args... args) {
80 return getDerived().visitSingleChild(std::move(Band),
81 std::forward<Args>(args)...);
82 }
83
84 RetTy visitSequence(isl::schedule_node_sequence Sequence, Args... args) {
85 return getDerived().visitMultiChild(std::move(Sequence),
86 std::forward<Args>(args)...);
87 }
88
89 RetTy visitSet(isl::schedule_node_set Set, Args... args) {
90 return getDerived().visitMultiChild(std::move(Set),
91 std::forward<Args>(args)...);
92 }
93
94 RetTy visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
95 return getDerived().visitNode(std::move(Leaf), std::forward<Args>(args)...);
96 }
97
98 RetTy visitMark(isl::schedule_node_mark Mark, Args... args) {
99 return getDerived().visitSingleChild(std::move(Mark),
100 std::forward<Args>(args)...);
101 }
102
103 RetTy visitExtension(isl::schedule_node_extension Extension, Args... args) {
104 return getDerived().visitSingleChild(std::move(Extension),
105 std::forward<Args>(args)...);
106 }
107
108 RetTy visitFilter(isl::schedule_node_filter Filter, Args... args) {
109 return getDerived().visitSingleChild(std::move(Filter),
110 std::forward<Args>(args)...);
111 }
112
113 RetTy visitSingleChild(isl::schedule_node Node, Args... args) {
114 return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...);
115 }
116
117 RetTy visitMultiChild(isl::schedule_node Node, Args... args) {
118 return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...);
119 }
120
121 RetTy visitNode(isl::schedule_node Node, Args... args) {
122 llvm_unreachable("Unimplemented other");
123 }
124};
125
126/// Recursively visit all nodes of a schedule tree.
127template <typename Derived, typename RetTy = void, typename... Args>
129 : ScheduleTreeVisitor<Derived, RetTy, Args...> {
130 using BaseTy = ScheduleTreeVisitor<Derived, RetTy, Args...>;
131 BaseTy &getBase() { return *this; }
132 const BaseTy &getBase() const { return *this; }
133 Derived &getDerived() { return *static_cast<Derived *>(this); }
134 const Derived &getDerived() const {
135 return *static_cast<const Derived *>(this);
136 }
137
138 /// When visiting an entire schedule tree, start at its root node.
139 RetTy visit(isl::schedule Schedule, Args... args) {
140 return getDerived().visit(Schedule.get_root(), std::forward<Args>(args)...);
141 }
142
143 // Necessary to allow overload resolution with the added visit(isl::schedule)
144 // overload.
145 RetTy visit(isl::schedule_node Node, Args... args) {
146 return getBase().visit(Node, std::forward<Args>(args)...);
147 }
148
149 /// By default, recursively visit the child nodes.
150 RetTy visitNode(isl::schedule_node Node, Args... args) {
151 for (unsigned i : rangeIslSize(0, Node.n_children()))
152 getDerived().visit(Node.child(i), std::forward<Args>(args)...);
153 return RetTy();
154 }
155};
156
157/// Recursively visit all nodes of a schedule tree while allowing changes.
158///
159/// The visit methods return an isl::schedule_node that is used to continue
160/// visiting the tree. Structural changes such as returning a different node
161/// will confuse the visitor.
162template <typename Derived, typename... Args>
164 : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node,
165 Args...> {
166 Derived &getDerived() { return *static_cast<Derived *>(this); }
167 const Derived &getDerived() const {
168 return *static_cast<const Derived *>(this);
169 }
170
172 return getDerived().visitChildren(Node);
173 }
174
176 if (!Node.has_children())
177 return Node;
178
180 while (true) {
181 It = getDerived().visit(It, std::forward<Args>(args)...);
182 if (!It.has_next_sibling())
183 break;
184 It = It.next_sibling();
185 }
186 return It.parent();
187 }
188};
189
190/// Is this node the marker for its parent band?
191bool isBandMark(const isl::schedule_node &Node);
192
193/// Extract the BandAttr from a band's wrapping marker. Can also pass the band
194/// itself and this method will try to find its wrapping mark. Returns nullptr
195/// if the band has no BandAttr.
196BandAttr *getBandAttr(isl::schedule_node MarkOrBand);
197
198/// Hoist all domains from extension nodes into the root domain node, such that
199/// there are no more extension nodes (which isl does not support for some
200/// operations). This assumes that the domains added by extension nodes do not
201/// overlap.
203
204/// Replace the AST band @p BandToUnroll by a sequence of all its iterations.
205///
206/// The implementation enumerates all points in the partial schedule and creates
207/// an ISL sequence node with one child for each point. The number of iterations
208/// must be a constant.
210
211/// Replace the AST band @p BandToUnroll by a partially unrolled equivalent.
212isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor);
213
214/// Loop-distribute the band @p BandToFission as much as possible.
216
217/// Build the desired set of partial tile prefixes.
218///
219/// We build a set of partial tile prefixes, which are prefixes of the vector
220/// loop that have exactly VectorWidth iterations.
221///
222/// 1. Drop all constraints involving the dimension that represents the
223/// vector loop.
224/// 2. Constrain the last dimension to get a set, which has exactly VectorWidth
225/// iterations.
226/// 3. Subtract the loop domain from it, project out the vector loop dimension
227/// and get a set that contains prefixes, which do not have exactly VectorWidth
228/// iterations.
229/// 4. Project out the vector loop dimension of the set that was built in the
230/// first step and subtract the set built in the previous step to get the
231/// desired set of prefixes.
232///
233/// @param ScheduleRange A range of a map, which describes a prefix schedule
234/// relation.
235isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth);
236
237/// Create an isl::union_set, which describes the isolate option based on
238/// IsolateDomain.
239///
240/// @param IsolateDomain An isl::set whose @p OutDimsNum last dimensions should
241/// belong to the current band node.
242/// @param OutDimsNum The number of dimensions that should belong to
243/// the current band node.
244isl::union_set getIsolateOptions(isl::set IsolateDomain, unsigned OutDimsNum);
245
246/// Create an isl::union_set, which describes the specified option for the
247/// dimension of the current node.
248///
249/// @param Ctx An isl::ctx, which is used to create the isl::union_set.
250/// @param Option The name of the option.
251isl::union_set getDimOptions(isl::ctx Ctx, const char *Option);
252
253/// Tile a schedule node.
254///
255/// @param Node The node to tile.
256/// @param Identifier A name that identifies this kind of tiling and
257/// that is used to mark the tiled loops in the
258/// generated AST.
259/// @param TileSizes A vector of tile sizes that should be used for
260/// tiling.
261/// @param DefaultTileSize A default tile size that is used for dimensions
262/// that are not covered by the TileSizes vector.
263isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier,
264 llvm::ArrayRef<int> TileSizes, int DefaultTileSize);
265
266/// Tile a schedule node and unroll point loops.
267///
268/// @param Node The node to register-tile.
269/// @param TileSizes A vector of tile sizes that should be used for
270/// tiling.
271/// @param DefaultTileSize A default tile size that is used for dimensions
/// that are not covered by the TileSizes vector.
274 int DefaultTileSize);
275
276/// Apply greedy fusion. That is, fuse, top-down, every loop that can be
277/// fused.
278///
279/// @param Sched Schedule tree in which to fuse all the loops.
280/// @param Deps Validity constraints that must be preserved.
282 const isl::union_map &Deps);
283
284} // namespace polly
285
286#endif // POLLY_SCHEDULETREETRANSFORM_H
boolean has_next_sibling() const
isl::schedule_node child(int pos) const
boolean has_children() const
isl::schedule_node next_sibling() const
class size n_children() const
isl::schedule_node parent() const
isl::schedule_node first_child() const
__isl_keep isl_schedule_node * get() const
isl::schedule_node get_root() const
enum isl_schedule_node_type isl_schedule_node_get_type(__isl_keep isl_schedule_node *node)
#define assert(exp)
isl::schedule_node applyRegisterTiling(isl::schedule_node Node, llvm::ArrayRef< int > TileSizes, int DefaultTileSize)
Tile a schedule node and unroll point loops.
isl::schedule applyGreedyFusion(isl::schedule Sched, const isl::union_map &Deps)
Apply greedy fusion.
llvm::iota_range< unsigned > rangeIslSize(unsigned Begin, isl::size End)
Check that End is valid and return an iterator from Begin to End.
Definition: ISLTools.cpp:597
BandAttr * getBandAttr(isl::schedule_node MarkOrBand)
Extract the BandAttr from a band's wrapping marker.
bool isBandMark(const isl::schedule_node &Node)
Is this node the marker for its parent band?
isl::schedule applyMaxFission(isl::schedule_node BandToFission)
Loop-distribute the band BandToFission as much as possible.
isl::union_set getIsolateOptions(isl::set IsolateDomain, unsigned OutDimsNum)
Create an isl::union_set, which describes the isolate option based on IsolateDomain.
isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor)
Replace the AST band BandToUnroll by a partially unrolled equivalent.
isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier, llvm::ArrayRef< int > TileSizes, int DefaultTileSize)
Tile a schedule node.
isl::union_set getDimOptions(isl::ctx Ctx, const char *Option)
Create an isl::union_set, which describes the specified option for the dimension of the current node.
isl::schedule hoistExtensionNodes(isl::schedule Sched)
Hoist all domains from extension into the root domain node, such that there are no more extension nod...
isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth)
Build the desired set of partial tile prefixes.
isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll)
Replace the AST band BandToUnroll by a sequence of all its iterations.
__isl_export isl_size isl_schedule_node_n_children(__isl_keep isl_schedule_node *node)
@ isl_schedule_node_mark
Definition: schedule_type.h:18
@ isl_schedule_node_filter
Definition: schedule_type.h:15
@ isl_schedule_node_domain
Definition: schedule_type.h:12
@ isl_schedule_node_band
Definition: schedule_type.h:10
@ isl_schedule_node_set
Definition: schedule_type.h:20
@ isl_schedule_node_extension
Definition: schedule_type.h:14
@ isl_schedule_node_sequence
Definition: schedule_type.h:19
@ isl_schedule_node_leaf
Definition: schedule_type.h:16
Recursively visit all nodes of a schedule tree.
RetTy visitNode(isl::schedule_node Node, Args... args)
By default, recursively visit the child nodes.
RetTy visit(isl::schedule Schedule, Args... args)
When visiting an entire schedule tree, start at its root node.
RetTy visit(isl::schedule_node Node, Args... args)
Recursively visit all nodes of a schedule tree while allowing changes.
const Derived & getDerived() const
isl::schedule_node visitChildren(isl::schedule_node Node, Args... args)
isl::schedule_node visitNode(isl::schedule_node Node, Args... args)
This class defines a simple visitor class that may be used for various schedule tree analysis purpose...
RetTy visitExtension(isl::schedule_node_extension Extension, Args... args)
RetTy visitMark(isl::schedule_node_mark Mark, Args... args)
RetTy visitSequence(isl::schedule_node_sequence Sequence, Args... args)
RetTy visitSingleChild(isl::schedule_node Node, Args... args)
RetTy visitSet(isl::schedule_node_set Set, Args... args)
const Derived & getDerived() const
RetTy visitMultiChild(isl::schedule_node Node, Args... args)
RetTy visitBand(isl::schedule_node_band Band, Args... args)
RetTy visitDomain(isl::schedule_node_domain Domain, Args... args)
RetTy visitFilter(isl::schedule_node_filter Filter, Args... args)
RetTy visitNode(isl::schedule_node Node, Args... args)
RetTy visitLeaf(isl::schedule_node_leaf Leaf, Args... args)
RetTy visit(isl::schedule_node Node, Args... args)
static TupleKindPtr Domain("Domain")
static TupleKindPtr Leaf("Leaf")
static TupleKindPtr Ctx