Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 38 additions & 47 deletions loopy/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,11 @@

import logging
import sys
from collections.abc import Hashable, Iterator, Mapping, Sequence, Set
from dataclasses import dataclass, replace
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Dict,
FrozenSet,
Hashable,
Iterator,
Mapping,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
)

Expand Down Expand Up @@ -155,7 +146,7 @@ class Barrier(ScheduleItem):

def gather_schedule_block(
schedule: Sequence[ScheduleItem], start_idx: int
) -> Tuple[Sequence[ScheduleItem], int]:
) -> tuple[Sequence[ScheduleItem], int]:
assert isinstance(schedule[start_idx], BeginBlockItem)
level = 0

Expand All @@ -176,7 +167,7 @@ def gather_schedule_block(

def generate_sub_sched_items(
schedule: Sequence[ScheduleItem], start_idx: int
) -> Iterator[Tuple[int, ScheduleItem]]:
) -> Iterator[tuple[int, ScheduleItem]]:
if not isinstance(schedule[start_idx], BeginBlockItem):
yield start_idx, schedule[start_idx]

Expand All @@ -203,7 +194,7 @@ def generate_sub_sched_items(

def get_insn_ids_for_block_at(
schedule: Sequence[ScheduleItem], start_idx: int
) -> FrozenSet[str]:
) -> frozenset[str]:
return frozenset(
sub_sched_item.insn_id
for i, sub_sched_item in generate_sub_sched_items(
Expand All @@ -212,7 +203,7 @@ def get_insn_ids_for_block_at(


def find_used_inames_within(
kernel: LoopKernel, sched_index: int) -> AbstractSet[str]:
kernel: LoopKernel, sched_index: int) -> set[str]:
assert kernel.linearization is not None
sched_item = kernel.linearization[sched_index]

Expand All @@ -234,7 +225,7 @@ def find_used_inames_within(
return result


def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]]:
def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, set[str]]:
"""Returns a dictionary mapping inames to other inames that are
always nested with them.
"""
Expand All @@ -257,11 +248,11 @@ def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]
return result


def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]]:
def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, set[str]]:
"""Returns a dictionary mapping inames to other inames that are
always nested around them.
"""
result: Dict[str, Set[str]] = {}
result: dict[str, set[str]] = {}

all_inames = kernel.all_inames()

Expand Down Expand Up @@ -299,14 +290,14 @@ def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[st

def find_loop_insn_dep_map(
kernel: LoopKernel,
loop_nest_with_map: Mapping[str, AbstractSet[str]],
loop_nest_around_map: Mapping[str, AbstractSet[str]]
) -> Mapping[str, AbstractSet[str]]:
loop_nest_with_map: Mapping[str, Set[str]],
loop_nest_around_map: Mapping[str, Set[str]]
) -> Mapping[str, set[str]]:
"""Returns a dictionary mapping inames to other instruction ids that need to
be scheduled before the iname should be eligible for scheduling.
"""

result: Dict[str, Set[str]] = {}
result: dict[str, set[str]] = {}

from loopy.kernel.data import ConcurrentTag, IlpBaseTag
for insn in kernel.instructions:
Expand Down Expand Up @@ -372,7 +363,7 @@ def find_loop_insn_dep_map(


def group_insn_counts(kernel: LoopKernel) -> Mapping[str, int]:
result: Dict[str, int] = {}
result: dict[str, int] = {}

for insn in kernel.instructions:
for grp in insn.groups:
Expand All @@ -382,7 +373,7 @@ def group_insn_counts(kernel: LoopKernel) -> Mapping[str, int]:


def gen_dependencies_except(
kernel: LoopKernel, insn_id: str, except_insn_ids: AbstractSet[str]
kernel: LoopKernel, insn_id: str, except_insn_ids: Set[str]
) -> Iterator[str]:
insn = kernel.id_to_insn[insn_id]
for dep_id in insn.depends_on:
Expand All @@ -396,9 +387,9 @@ def gen_dependencies_except(


def get_priority_tiers(
wanted: AbstractSet[int],
priorities: AbstractSet[Sequence[int]]
) -> Iterator[AbstractSet[int]]:
wanted: Set[int],
priorities: Set[Sequence[int]]
) -> Iterator[set[int]]:
# Get highest priority tier candidates: These are the first inames
# of all the given priority constraints
candidates = set()
Expand Down Expand Up @@ -677,32 +668,32 @@ class SchedulerState:
order with instruction priorities as tie breaker.
"""
kernel: LoopKernel
loop_nest_around_map: Mapping[str, AbstractSet[str]]
loop_insn_dep_map: Mapping[str, AbstractSet[str]]
loop_nest_around_map: Mapping[str, set[str]]
loop_insn_dep_map: Mapping[str, set[str]]

breakable_inames: AbstractSet[str]
ilp_inames: AbstractSet[str]
vec_inames: AbstractSet[str]
concurrent_inames: AbstractSet[str]
breakable_inames: set[str]
ilp_inames: set[str]
vec_inames: set[str]
concurrent_inames: set[str]

insn_ids_to_try: Optional[AbstractSet[str]]
insn_ids_to_try: set[str] | None
active_inames: Sequence[str]
entered_inames: FrozenSet[str]
enclosing_subkernel_inames: Tuple[str, ...]
entered_inames: frozenset[str]
enclosing_subkernel_inames: tuple[str, ...]
schedule: Sequence[ScheduleItem]
scheduled_insn_ids: AbstractSet[str]
unscheduled_insn_ids: AbstractSet[str]
scheduled_insn_ids: frozenset[str]
unscheduled_insn_ids: set[str]
preschedule: Sequence[ScheduleItem]
prescheduled_insn_ids: AbstractSet[str]
prescheduled_inames: AbstractSet[str]
prescheduled_insn_ids: set[str]
prescheduled_inames: set[str]
may_schedule_global_barriers: bool
within_subkernel: bool
group_insn_counts: Mapping[str, int]
active_group_counts: Mapping[str, int]
insns_in_topologically_sorted_order: Sequence[InstructionBase]

@property
def last_entered_loop(self) -> Optional[str]:
def last_entered_loop(self) -> str | None:
if self.active_inames:
return self.active_inames[-1]
else:
Expand All @@ -718,7 +709,7 @@ def get_insns_in_topologically_sorted_order(
kernel: LoopKernel) -> Sequence[InstructionBase]:
from pytools.graph import compute_topological_order

rev_dep_map: Dict[str, Set[str]] = {
rev_dep_map: dict[str, set[str]] = {
not_none(insn.id): set() for insn in kernel.instructions}
for insn in kernel.instructions:
for dep in insn.depends_on:
Expand All @@ -733,7 +724,7 @@ def get_insns_in_topologically_sorted_order(
# Instead of returning these features as a key, we assign an id to
# each set of features to avoid comparing them which can be expensive.
insn_id_to_feature_id = {}
insn_features: Dict[Hashable, int] = {}
insn_features: dict[Hashable, int] = {}
for insn in kernel.instructions:
feature = (insn.within_inames, insn.groups, insn.conflicts_with_groups)
if feature not in insn_features:
Expand Down Expand Up @@ -890,7 +881,7 @@ def _get_outermost_diverging_inames(
tree: LoopTree,
within1: InameStrSet,
within2: InameStrSet
) -> Tuple[InameStr, InameStr]:
) -> tuple[InameStr, InameStr]:
"""
For loop nestings *within1* and *within2*, returns the first inames at which
the loop nests diverge in the loop nesting tree *tree*.
Expand Down Expand Up @@ -2180,7 +2171,7 @@ def __init__(self, kernel):
def generate_loop_schedules(
kernel: LoopKernel,
callables_table: CallablesTable,
debug_args: Optional[Dict[str, Any]] = None) -> Iterator[LoopKernel]:
debug_args: Mapping[str, Any] | None = None) -> Iterator[LoopKernel]:
"""
.. warning::

Expand Down Expand Up @@ -2236,7 +2227,7 @@ def _postprocess_schedule(kernel, callables_table, gen_sched):
def _generate_loop_schedules_inner(
kernel: LoopKernel,
callables_table: CallablesTable,
debug_args: Optional[Dict[str, Any]]) -> Iterator[LoopKernel]:
debug_args: Mapping[str, Any] | None) -> Iterator[LoopKernel]:
if debug_args is None:
debug_args = {}

Expand Down Expand Up @@ -2337,7 +2328,7 @@ def _generate_loop_schedules_inner(
get_insns_in_topologically_sorted_order(kernel)),
)

schedule_gen_kwargs: Dict[str, Any] = {}
schedule_gen_kwargs: dict[str, Any] = {}

def print_longest_dead_end():
if debug.interactive:
Expand Down Expand Up @@ -2402,7 +2393,7 @@ def print_longest_dead_end():


schedule_cache: WriteOncePersistentDict[
Tuple[LoopKernel, CallablesTable],
tuple[LoopKernel, CallablesTable],
LoopKernel
] = WriteOncePersistentDict(
"loopy-schedule-cache-v4-"+DATA_MODEL_VERSION,
Expand Down
23 changes: 12 additions & 11 deletions loopy/schedule/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
THE SOFTWARE.
"""

from collections.abc import Hashable, Iterator, Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Generic, Hashable, Iterator, List, Optional, Sequence, Tuple, TypeVar
from typing import Generic, TypeVar

from immutables import Map

Expand Down Expand Up @@ -70,11 +71,11 @@ class Tree(Generic[NodeT]):
this allocates a new stack frame for each iteration of the operation.
"""

_parent_to_children: Map[NodeT, Tuple[NodeT, ...]]
_child_to_parent: Map[NodeT, Optional[NodeT]]
_parent_to_children: Map[NodeT, tuple[NodeT, ...]]
_child_to_parent: Map[NodeT, NodeT | None]

@staticmethod
def from_root(root: NodeT) -> "Tree[NodeT]":
def from_root(root: NodeT) -> Tree[NodeT]:
return Tree(Map({root: ()}),
Map({root: None}))

Expand All @@ -89,7 +90,7 @@ def root(self) -> NodeT:
return guess

@memoize_method
def ancestors(self, node: NodeT) -> Tuple[NodeT, ...]:
def ancestors(self, node: NodeT) -> tuple[NodeT, ...]:
"""
Returns a :class:`tuple` of nodes that are ancestors of *node*.
"""
Expand All @@ -104,15 +105,15 @@ def ancestors(self, node: NodeT) -> Tuple[NodeT, ...]:

return (parent,) + self.ancestors(parent)

def parent(self, node: NodeT) -> Optional[NodeT]:
def parent(self, node: NodeT) -> NodeT | None:
"""
Returns the parent of *node*.
"""
assert node in self

return self._child_to_parent[node]

def children(self, node: NodeT) -> Tuple[NodeT, ...]:
def children(self, node: NodeT) -> tuple[NodeT, ...]:
"""
Returns the children of *node*.
"""
Expand Down Expand Up @@ -150,7 +151,7 @@ def __contains__(self, node: NodeT) -> bool:
"""Return *True* if *node* is a node in the tree."""
return node in self._child_to_parent

def add_node(self, node: NodeT, parent: NodeT) -> "Tree[NodeT]":
def add_node(self, node: NodeT, parent: NodeT) -> Tree[NodeT]:
"""
Returns a :class:`Tree` with added node *node* having a parent
*parent*.
Expand All @@ -165,7 +166,7 @@ def add_node(self, node: NodeT, parent: NodeT) -> "Tree[NodeT]":
.set(node, ())),
self._child_to_parent.set(node, parent))

def replace_node(self, node: NodeT, new_node: NodeT) -> "Tree[NodeT]":
def replace_node(self, node: NodeT, new_node: NodeT) -> Tree[NodeT]:
"""
Returns a copy of *self* with *node* replaced with *new_node*.
"""
Expand Down Expand Up @@ -207,7 +208,7 @@ def replace_node(self, node: NodeT, new_node: NodeT) -> "Tree[NodeT]":
return Tree(parent_to_children_mut.finish(),
child_to_parent_mut.finish())

def move_node(self, node: NodeT, new_parent: Optional[NodeT]) -> "Tree[NodeT]":
def move_node(self, node: NodeT, new_parent: NodeT | None) -> Tree[NodeT]:
"""
Returns a copy of *self* with node *node* as a child of *new_parent*.
"""
Expand Down Expand Up @@ -262,7 +263,7 @@ def __str__(self) -> str:
├── D
└── E
"""
def rec(node: NodeT) -> List[str]:
def rec(node: NodeT) -> list[str]:
children_result = [rec(c) for c in self.children(node)]

def post_process_non_last_child(children: Sequence[str]) -> list[str]:
Expand Down
5 changes: 3 additions & 2 deletions loopy/transform/precompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ def storage_axis_exprs(storage_axis_sources, args) -> Sequence[ExpressionT]:
# {{{ gather rule invocations

class RuleInvocationGatherer(RuleAwareIdentityMapper):
def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within):
def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within) \
-> None:
super().__init__(rule_mapping_context)

from loopy.symbolic import SubstitutionRuleExpander
Expand All @@ -167,7 +168,7 @@ def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within):
self.subst_tag = subst_tag
self.within = within

self.access_descriptors: List[RuleAccessDescriptor] = []
self.access_descriptors: list[RuleAccessDescriptor] = []

def map_substitution(self, name, tag, arguments, expn_state):
process_me = name == self.subst_name
Expand Down