Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/datumaro/components/dataset_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,10 @@ def __iter__(self) -> Iterator[DatasetItem]:

def __len__(self) -> int:
    """Return the number of items in this subset, caching the result.

    If the source reports this subset as its only subset, the source's own
    length is reused, so the items are not iterated just to be counted.
    Otherwise the items are counted by iterating over this subset once.
    The value is stored in ``self._length`` and returned on later calls.
    """
    if self._length is None:
        if self._source.subset_names == {self._subset}:
            # Sole subset of the source: take the source's length directly.
            self._length = len(self._source)
        else:
            self._length = sum(1 for _ in self)
    return self._length

def subsets(self) -> Dict[str, IDataset]:
Expand Down Expand Up @@ -695,7 +698,7 @@ def __init__(
):
if not source.is_stream:
raise ValueError("source should be a stream.")
self._subset_names = list(source.subsets().keys())
self._subset_names = set(source.subsets().keys())
super().__init__(
source=source,
infos=infos,
Expand Down Expand Up @@ -746,7 +749,10 @@ def __iter__(self) -> Iterator[DatasetItem]:

def __len__(self) -> int:
    """Return the number of items in this storage, caching the result.

    If no transforms are set, the source's own length is reused, so the
    items are not iterated just to be counted. Otherwise the items are
    counted by iterating over this storage once. The value is stored in
    ``self._length`` and returned on later calls.
    """
    if self._length is None:
        if not self._transforms:
            # Nothing is stacked on the source: take its length directly.
            self._length = len(self._source)
        else:
            self._length = sum(1 for _ in self)
    return self._length

def put(self, item: DatasetItem) -> None:
Expand All @@ -764,7 +770,7 @@ def remove(self, id: str, subset: Optional[str] = None) -> None:
def get_subset(self, name: str) -> IDataset:
    """Return the subset stored under ``name``.

    The lookup indexes the mapping returned by ``self.subsets()``, so an
    unknown name fails the same way that mapping fails for a missing key.
    """
    subsets_by_name = self.subsets()
    return subsets_by_name[name]

def _collect_subset_names(self):
def _collect_subset_names(self) -> set[str]:
assert not self._keeps_subsets_intact

item_generator = stacked_transform = self.stacked_transform
Expand All @@ -789,7 +795,7 @@ def _item_generator():
return {item.subset for item in item_generator}

@property
def subset_names(self):
def subset_names(self) -> set[str]:
if self._subset_names is None:
self._subset_names = self._collect_subset_names()

Expand Down
17 changes: 16 additions & 1 deletion tests/unit/components/test_dataset_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from datumaro.components.annotation import AnnotationType
from datumaro.components.dataset_base import CategoriesInfo, DatasetInfo
from datumaro.components.dataset_storage import StreamDatasetStorage
from datumaro.components.dataset_storage import StreamDatasetStorage, StreamSubset
from datumaro.plugins.transforms import MapSubsets, RandomSplit, RemapLabels, Rename, UpdateInfos
from datumaro.util.definitions import DEFAULT_SUBSET_NAME

Expand Down Expand Up @@ -169,3 +169,18 @@ def test_mixed_transform(
# Check Rename
self._test_loop(fxt_stream_extractor, storage, n_calls, id_pattern="rename_{idx}")
assert fxt_stream_extractor.__iter__.call_count == n_calls


class StreamSubsetTest:
    """Unit tests for the length shortcut of ``StreamSubset``."""

    def test_single_subset_len_uses_dataset_len(self):
        """``len()`` of the sole subset must come from the source's ``__len__``."""
        dataset_storage_mock = MagicMock(spec=StreamDatasetStorage)
        dataset_storage_mock.subset_names = {"FOO"}

        class StreamSubsetWrap(StreamSubset):
            def __iter__(self):
                # Taking the length must not fall back to iterating items.
                # Raise explicitly: a bare `assert False` is stripped under
                # `python -O` and carries no message when it does fire.
                raise AssertionError("should not iterate items to get length")

        subset = StreamSubsetWrap(dataset_storage_mock, "FOO")
        len(subset)
        assert dataset_storage_mock.__len__.called
3 changes: 2 additions & 1 deletion tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2562,8 +2562,9 @@ def test_item_iteration_count_no_transform(self, items_for_subsets):
extractor.item_iterated_count = 0
dataset = StreamDataset.from_extractors(extractor)

# no need to iterate items to get subsets
# no need to iterate items to get subsets or length
dataset.subsets()
len(dataset)
assert extractor.item_iterated_count == 0

# accessing items through subsets only iterates relevant items
Expand Down
Loading