58 changes: 48 additions & 10 deletions src/datumaro/components/dataset_storage.py
@@ -2,6 +2,8 @@
 #
 # SPDX-License-Identifier: MIT
 
+from __future__ import annotations
+
 import logging as log
 from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union
 
@@ -628,15 +630,21 @@ def __getitem__(self, idx: int) -> DatasetItem:
 
 
 class StreamSubset(IDataset):
-    def __init__(self, source: IDataset, subset: str) -> None:
+    def __init__(self, source: StreamDatasetStorage, subset: str) -> None:
         if not source.is_stream:
             raise ValueError("source should be a stream.")
         self._source = source
         self._subset = subset
         self._length = None
 
     def __iter__(self) -> Iterator[DatasetItem]:
-        for item in self._source:
+        if self._source._keeps_subsets_intact:
+            source = self._source._apply_stacked_transform(
+                self._source._source.get_subset(self._subset)
+            )
+        else:
+            source = self._source
+        for item in source:
             if item.subset == self._subset:
                 yield item
 
@@ -688,7 +696,6 @@ def __init__(
         if not source.is_stream:
             raise ValueError("source should be a stream.")
         self._subset_names = list(source.subsets().keys())
-        self._transform_ids_for_latest_subset_names = []
         super().__init__(
             source=source,
             infos=infos,
@@ -697,6 +704,7 @@
             ann_types=ann_types,
             raise_on_malformed_transform=raise_on_malformed_transform,
         )
+        self._keeps_subsets_intact = True
 
     def is_cache_initialized(self) -> bool:
         log.debug("This function has no effect on streaming.")
@@ -706,17 +714,22 @@ def init_cache(self) -> None:
         log.debug("This function has no effect on streaming.")
         pass
 
-    @property
-    def stacked_transform(self) -> IDataset:
+    def _apply_stacked_transform(self, source: IDataset):
         if self._transforms:
             transform = _StackedTransform(
-                self._source,
+                source,
                 self._transforms,
                 raise_on_malformed_transform=self._raise_on_malformed_transform,
             )
             self._drop_malformed_transforms(transform.malformed_transform_indices)
         else:
-            transform = self._source
+            transform = source
+
+        return transform
+
+    @property
+    def stacked_transform(self) -> IDataset:
+        transform = self._apply_stacked_transform(self._source)
 
         self._flush_changes = True
         return transform
Expand Down Expand Up @@ -751,11 +764,34 @@ def remove(self, id: str, subset: Optional[str] = None) -> None:
def get_subset(self, name: str) -> IDataset:
return self.subsets()[name]

def _collect_subset_names(self):
assert not self._keeps_subsets_intact

item_generator = stacked_transform = self.stacked_transform
assert isinstance(stacked_transform, _StackedTransform)

if not stacked_transform.is_local:
self._keeps_subsets_intact = False
else:
self._keeps_subsets_intact = True

def _item_generator():
for item in self._source:
transformed_item = stacked_transform.transform_item(item)
if transformed_item is None:
continue
if item.subset != transformed_item.subset:
self._keeps_subsets_intact = False
yield transformed_item

item_generator = _item_generator()

return {item.subset for item in item_generator}

@property
def subset_names(self):
if self._transform_ids_for_latest_subset_names != [id(t) for t in self._transforms]:
self._subset_names = {item.subset for item in self}
self._transform_ids_for_latest_subset_names = [id(t) for t in self._transforms]
if self._subset_names is None:
self._subset_names = self._collect_subset_names()

return self._subset_names

@@ -764,6 +800,8 @@ def subsets(self) -> Dict[str, IDataset]:
 
     def transform(self, method: Type[Transform], *args, **kwargs) -> None:
        super().transform(method, *args, **kwargs)
+        self._keeps_subsets_intact = None if issubclass(method, ItemTransform) else False
+        self._subset_names = None
 
     def get_annotated_items(self) -> int:
         return super().get_annotated_items()
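Taken together, the storage changes let a streaming dataset avoid whole-stream scans: subsets() iterates at most once to verify that no stacked ItemTransform moved an item between subsets, and each subset then delegates to the source's own get_subset() instead of filtering the full stream. A minimal usage sketch of the intended behavior — the Identity transform, the extractor variable, and the "train" subset name are illustrative, not part of the patch, and the import paths assume the usual Datumaro module layout:

    from datumaro.components.dataset import StreamDataset
    from datumaro.components.transformer import ItemTransform

    class Identity(ItemTransform):
        def transform_item(self, item):
            return item  # item.subset is unchanged, so subsets stay "intact"

    # 'extractor' stands for any stream-capable source (is_stream == True)
    dataset = StreamDataset.from_extractors(extractor)
    dataset.transform(Identity)

    # One full pass: _collect_subset_names() confirms no item changed subset
    # and leaves _keeps_subsets_intact set to True.
    names = dataset.subsets()

    # Now each StreamSubset pulls only its own items through
    # source.get_subset(name) rather than filtering the entire stream.
    for item in dataset.get_subset("train"):
        pass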
35 changes: 19 additions & 16 deletions src/datumaro/components/merge/extractor_merger.py
@@ -30,7 +30,7 @@ def check_identicalness(seq: Sequence[T], raise_error_on_empty: bool = True) ->
 
 
 class ExtractorMerger(DatasetBase):
-    """A simple class to merge single-subset extractors."""
+    """A simple class to merge non-intersecting single-subset extractors."""
 
     def __init__(
         self,
@@ -50,9 +50,12 @@ def __init__(
 
         self._is_stream = check_identicalness([s.is_stream for s in sources])
 
-        self._subsets: Dict[str, List[SubsetBase]] = defaultdict(list)
+        subsets: Dict[str, List[SubsetBase]] = defaultdict(list)
         for source in sources:
-            self._subsets[source.subset] += [source]
+            subsets[source.subset] += [source]
+            assert len(subsets[source.subset]) == 1
+
+        self._subsets = {subset_name: sources[0] for subset_name, sources in subsets.items()}
 
     def infos(self) -> DatasetInfo:
         return self._infos
@@ -61,23 +64,23 @@ def categories(self) -> CategoriesInfo:
         return self._categories
 
     def __iter__(self) -> Iterator[DatasetItem]:
-        for sources in self._subsets.values():
-            for source in sources:
-                yield from source
+        for subset in self._subsets.values():
+            yield from subset
 
+    def get_subset(self, name: str):
+        if name not in self._subsets:
+            raise KeyError(
+                "Unknown subset '%s', available subsets: %s" % (name, set(self._subsets))
+            )
+        return self._subsets[name]
+
     def __len__(self) -> int:
-        return sum(len(source) for sources in self._subsets.values() for source in sources)
+        return sum(len(subset) for subset in self._subsets.values())
 
     def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]:
-        if subset is not None and (sources := self._subsets.get(subset, [])):
-            for source in sources:
-                if item := source.get(id, subset):
-                    return item
-
-        for sources in self._subsets.values():
-            for source in sources:
-                if item := source.get(id=id, subset=source.subset):
-                    return item
+        if source := self._subsets.get(subset):
+            if item := source.get(id, subset):
+                return item
 
         return None
 
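With the one-source-per-subset invariant, _subsets maps a subset name directly to its extractor, so lookups no longer scan sibling sources. A sketch of the resulting behavior, assuming two hypothetical single-subset SubsetBase implementations (train_source and val_source are not part of the patch). Note one behavioral consequence: get(id) without a subset now returns None, since self._subsets.get(None) cannot match anything:

    merger = ExtractorMerger(sources=[train_source, val_source])

    merger.get_subset("train")   # the train source itself, no wrapping list
    merger.get("0001", "val")    # consults only the "val" source
    merger.get("0001")           # None: a subset name is now required
    merger.get_subset("test")    # KeyError listing the available subsets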
2 changes: 1 addition & 1 deletion src/datumaro/plugins/data_formats/coco/extractor_merger.py
@@ -86,6 +86,6 @@ def __init__(self, sources: Sequence[_CocoBase]):
             grouped_by_subset[s.subset] += [s]
 
         self._subsets = {
-            subset: [COCOTaskMergedBase(sources, subset)]
+            subset: COCOTaskMergedBase(sources, subset)
             for subset, sources in grouped_by_subset.items()
         }
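This keeps the COCO merger consistent with the new ExtractorMerger contract: a subset name maps to a single dataset, not a one-element list. Schematically (hypothetical "train" key for illustration):

    # before: self._subsets == {"train": [COCOTaskMergedBase(sources, "train")]}
    # after:  self._subsets == {"train": COCOTaskMergedBase(sources, "train")}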
8 changes: 2 additions & 6 deletions tests/unit/data_formats/test_streaming_efficiency.py
@@ -37,11 +37,7 @@ def __iter__(self):
         # before yielded, references are only here
         assert sys.getrefcount(item) == 2
 
-        # after yielded, there are more references (e.g. where it's yielded from)
-        # number of references doesn't have to increase in general,
-        # but it should due to how our code works
         yield item
-        assert sys.getrefcount(item) > 2
 
         # after next item yielded, ref count is 2 again - i.e. item was not saved anywhere
         yield DatasetItem(
@@ -180,7 +176,7 @@ def test_streaming_importers(test_dir, export_format, fxt_dataset):
         pass
     assert init_counter.count == len(fxt_dataset) * 3
 
-    # subset access DOES ITERATE over ALL items in the dataset,
+    # subset access only iterates the relevant items,
     # though item annotations and media data are not parsed before access
     # (depending on the extractor support)
     for subset in parsed_dataset.subsets().values():
@@ -192,4 +188,4 @@ def test_streaming_importers(test_dir, export_format, fxt_dataset):
             assert isinstance(item.annotations, Annotations)
             assert item.annotations_are_initialized
 
-    assert init_counter.count == len(fxt_dataset) * (3 + len(parsed_dataset.subsets()))
+    assert init_counter.count == len(fxt_dataset) * 4
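The deleted assertion depended on CPython holding an extra reference to an item while it is being yielded, which is an implementation detail rather than a guarantee. The check that remains is the meaningful one, and it can be reproduced outside Datumaro — a CPython-specific sketch; recall that sys.getrefcount counts its own argument as one reference:

    import sys

    def items():
        for _ in range(3):
            item = object()
            # the only references are the local name and getrefcount's argument
            assert sys.getrefcount(item) == 2
            yield item

    for _ in items():
        pass  # the consumer drops its reference before the next item is built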
119 changes: 116 additions & 3 deletions tests/unit/test_dataset.py
@@ -25,7 +25,7 @@
     ProgressReporter,
 )
 from datumaro.components.dataset import DEFAULT_FORMAT, Dataset, StreamDataset, eager_mode
-from datumaro.components.dataset_base import DatasetBase, DatasetItem, SubsetBase
+from datumaro.components.dataset_base import DatasetBase, DatasetItem, IDataset, SubsetBase
 from datumaro.components.dataset_item_storage import ItemStatus
 from datumaro.components.environment import Environment
 from datumaro.components.errors import (
@@ -2463,14 +2463,14 @@ def __iter__(self):
                     items_for_subsets[name][0], items_for_subsets[name][1], name
                 )
 
-        return SrcExtractor()
+        return SrcExtractor
 
     @mark_requirement(Requirements.DATUM_GENERAL_REQ)
     @pytest.mark.parametrize(
         "items_for_subsets", [{"train": (1, 3)}, {"train": (1, 3), "val": (3, 6), "test": (9, 13)}]
     )
     def test_annotation_initializations(self, items_for_subsets):
-        extractor = self._make_extractor(items_for_subsets)
+        extractor = self._make_extractor(items_for_subsets)()
         dataset_length = len(extractor)
 
         dataset = StreamDataset.from_extractors(extractor)
@@ -2521,3 +2521,116 @@ def test_annotation_initializations(self, items_for_subsets):
             subset_dataset = dataset.get_subset(subset).as_dataset()
             len([item.annotations for item in subset_dataset])
         assert extractor.ann_init_counter == dataset_length * 3
+
+    @staticmethod
+    def _make_extractor_with_subset_access(items_for_subsets: Dict[str, Tuple[int, int]]):
+        class SrcExtractor(StreamDatasetTest._make_extractor(items_for_subsets)):
+            def __init__(self):
+                super().__init__()
+                self.item_iterated_count = 0
+
+            def __iter__(self):
+                for subset_name in items_for_subsets:
+                    yield from self.get_subset(subset_name)
+
+            def get_subset(self, name: str) -> IDataset:
+                assert name in items_for_subsets
+                parent = self
+
+                class _SubsetExtractor(SubsetBase):
+                    def __iter__(self):
+                        for item in super(SrcExtractor, parent).__iter__():
+                            if item.subset == name:
+                                yield item
+                                parent.item_iterated_count += 1
+
+                    @property
+                    def is_stream(self):
+                        return True
+
+                return _SubsetExtractor()
+
+        return SrcExtractor
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    @pytest.mark.parametrize(
+        "items_for_subsets", [{"train": (1, 3)}, {"train": (1, 3), "val": (3, 6), "test": (9, 13)}]
+    )
+    def test_item_iteration_count_no_transform(self, items_for_subsets):
+        extractor = self._make_extractor_with_subset_access(items_for_subsets)()
+        dataset_length = len(extractor)
+        extractor.item_iterated_count = 0
+        dataset = StreamDataset.from_extractors(extractor)
+
+        # no need to iterate items to get subsets
+        dataset.subsets()
+        assert extractor.item_iterated_count == 0
+
+        # accessing items through subsets only iterates relevant items
+        for subset in dataset.subsets():
+            subset_dataset = dataset.get_subset(subset).as_dataset()
+            for _ in subset_dataset:
+                pass
+        assert extractor.item_iterated_count == dataset_length
+
+        assert extractor.ann_init_counter == 0
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    @pytest.mark.parametrize(
+        "items_for_subsets", [{"train": (1, 3)}, {"train": (1, 3), "val": (3, 6), "test": (9, 13)}]
+    )
+    def test_item_iteration_count_item_transform(self, items_for_subsets):
+        extractor = self._make_extractor_with_subset_access(items_for_subsets)()
+        dataset_length = len(extractor)
+        extractor.item_iterated_count = 0
+        dataset = StreamDataset.from_extractors(extractor)
+
+        class TestTransform(ItemTransform):
+            def transform_item(self, item):
+                return item
+
+        dataset.transform(TestTransform)
+
+        # iterates items to collect subset names
+        dataset.subsets()
+        assert extractor.item_iterated_count == dataset_length
+        extractor.item_iterated_count = 0
+
+        # accessing items through subsets only iterates relevant items
+        for subset in dataset.subsets():
+            subset_dataset = dataset.get_subset(subset).as_dataset()
+            for _ in subset_dataset:
+                pass
+        assert extractor.item_iterated_count == dataset_length
+
+        assert extractor.ann_init_counter == 0
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    @pytest.mark.parametrize(
+        "items_for_subsets", [{"train": (1, 3)}, {"train": (1, 3), "val": (3, 6), "test": (9, 13)}]
+    )
+    def test_item_iteration_count_general_transform(self, items_for_subsets):
+        extractor = self._make_extractor_with_subset_access(items_for_subsets)()
+        dataset_length = len(extractor)
+        extractor.item_iterated_count = 0
+        dataset = StreamDataset.from_extractors(extractor)
+
+        class TestTransform(Transform):
+            def __iter__(self):
+                yield from self._extractor
+
+        dataset.transform(TestTransform)
+
+        # iterates items to collect subset names
+        dataset.subsets()
+        assert extractor.item_iterated_count == dataset_length
+        extractor.item_iterated_count = 0
+
+        # iterates ALL items to access subset items
+        for subset in dataset.subsets():
+            subset_dataset = dataset.get_subset(subset).as_dataset()
+            for _ in subset_dataset:
+                pass
+        assert extractor.item_iterated_count == dataset_length * len(items_for_subsets)
+
+        assert extractor.ann_init_counter == 0
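Read together, the three tests pin down the iteration contract of StreamDataset. In summary (my reading of the patch, not text from it):

    # transforms applied       dataset.subsets()      iterating every subset once
    # ------------------       -----------------      ---------------------------
    # none                     no item iteration      len(dataset) items in total
    # ItemTransform(s) only    one full pass          len(dataset) items in total
    # any general Transform    one full pass          len(dataset) * num_subsets items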