74 commits
1d7c685
added requirements-tensorflow-tpu.txt and tpu configuration in .kokoro
kharshith-k Jun 16, 2025
19b5e6b
updated .kokoro/github/ubuntu/tpu/build.sh with jax and torch backend…
kharshith-k Jun 16, 2025
d203ca3
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 18, 2025
f45e5d0
Changed the tpu CI config files path to .github from .kokoro
kharshith-k Jun 18, 2025
6771cc0
Added new job in .github/workflows/actions.yml to run TPU tests
kharshith-k Jun 18, 2025
87d36e7
fixed runs-on option in actions.yml for tpu_build job to run on self…
kharshith-k Jun 18, 2025
9901298
Added another runner in the actions TPU job
kharshith-k Jun 18, 2025
be97210
Update continuous.cfg
kharshith-k Jun 18, 2025
a1cd5c3
Update presubmit.cfg
kharshith-k Jun 18, 2025
c5e3a5c
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 23, 2025
f0ab676
Update actions.yml
kharshith-k Jun 23, 2025
09161d7
Developed Dockerfile for TPU build job in actions.yml
kharshith-k Jun 24, 2025
9a3948f
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 24, 2025
058fdff
Update actions.yml
kharshith-k Jun 24, 2025
d47e39e
Included a few more runners in tpu_build job
kharshith-k Jun 26, 2025
a6a59d7
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 26, 2025
ba4f6ae
Using linux-x86-ct6e-44-1tpu
kharshith-k Jun 26, 2025
a5a3624
Modified requirements-common.txt and updated requirements-tensorflow-…
kharshith-k Jun 30, 2025
b9998af
Added Dtypes_TPU_tests.py and requirements-jax-tpu.txt
kharshith-k Jul 22, 2025
f68be97
Progress bar now handles `steps_per_execution`. (#21422)
hertschuh Jun 26, 2025
1018abf
Fix symbolic call of `logsumexp` with int axis. (#21428)
hertschuh Jun 27, 2025
0da77e4
Only allow deserialization of `KerasSaveable`s by module and name. (#…
hertschuh Jun 29, 2025
cb639c5
commented tensorflow deps
kharshith-k Jul 2, 2025
c0d1743
Added log of dtypes_test_tpu.py and the test script for the same
kharshith-k Jul 2, 2025
306e6e7
modified dtypes_test_tpu.py as per pre-commit standards
kharshith-k Jul 2, 2025
4e584fc
Added TPU initialization and teardown functionalities in conftest.py, …
kharshith-k Jul 3, 2025
bb09e95
Added dtypes_test_TPU.py and dtypes_new_test.py, modified conftest.py
kharshith-k Jul 9, 2025
8a63d09
Added Dockerfile and tests list command
kharshith-k Jul 23, 2025
4651454
Updated Dockerfile
kharshith-k Jul 28, 2025
40af241
Restored Dockerfile to previous changes
kharshith-k Jul 28, 2025
64420d5
updated actions.yml file to install and configure docker engine on se…
kharshith-k Jul 28, 2025
da84de5
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jul 28, 2025
d69277d
updated actions.yml file to include container option
kharshith-k Jul 28, 2025
1c307fc
updated actions.yml file to include container option without volume b…
kharshith-k Jul 28, 2025
693886b
updated actions.yml file to change TPU
kharshith-k Jul 28, 2025
e74b851
Updated container path in build-and-test-on-tpu job
kharshith-k Jul 29, 2025
d31b3c4
separated TPU workflow from actions.yml
kharshith-k Jul 29, 2025
a70d19e
updated trigger condition for TPU tests workflow
kharshith-k Jul 29, 2025
5f5b609
updated container usage configuration for TPU tests workflow
kharshith-k Jul 29, 2025
72e729f
updated env vars for TPU tests workflow
kharshith-k Jul 29, 2025
e129299
updated env vars parsing syntax in TPU tests workflow
kharshith-k Jul 29, 2025
3fe5b57
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
10df307
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
dd21e09
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
328628f
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
01f0c17
updated image name in TPU tests workflow
kharshith-k Jul 29, 2025
3e41c37
updated image name with generic ubuntu image
kharshith-k Jul 29, 2025
5e55c2c
updated tpu-tests to use ghcr
kharshith-k Jul 29, 2025
ea9ff88
updated tpu-tests to store built image as local tar
kharshith-k Jul 29, 2025
6d92aa9
updated image name from ubuntu:22.04 to docker:24.0-cli in tpu tests …
kharshith-k Jul 29, 2025
3c75bf8
updated image name from docker:24.0-cli to ubuntu:22.04 in tpu tests…
kharshith-k Jul 29, 2025
1589a75
added volume mount from host in load-and-test-job
kharshith-k Jul 29, 2025
36bd682
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jul 29, 2025
04112cf
Reverted tpu-tests.yml to version using ghcr.io for image storage
kharshith-k Jul 29, 2025
87d7ad8
Removed custom dtypes_test files for TPU testing and restored origina…
kharshith-k Aug 12, 2025
6cb097c
Updated tpu-tests.yml to pull image from GCP artifact registry
kharshith-k Aug 12, 2025
4829f1b
Resolved conflicts in actions.yml
kharshith-k Aug 12, 2025
a2eb306
Added a workflow to check service accounts associated with self hoste…
kharshith-k Aug 12, 2025
23579c4
Made find_sa.yml specific to linux-x86-ct6e-44-1tpu
kharshith-k Aug 12, 2025
dac6433
Added container tag to find_sa.yml
kharshith-k Aug 12, 2025
05461c1
Checking SA for linux-x86-ct5lp-112-4tpu
kharshith-k Aug 12, 2025
078dcee
Checking SA for linux-x86-ct6e-44-1tpu-nxgm7-runner-vb87c
kharshith-k Aug 12, 2025
016c68d
Using SA for auth in tpu-tests
kharshith-k Aug 12, 2025
02657f0
Updated SA with container tag for auth in tpu-tests
kharshith-k Aug 12, 2025
7167952
Added docker socket mount test
kharshith-k Aug 12, 2025
543cf65
Updated tpu-tests to just pull and test the image from artifact regis…
kharshith-k Aug 14, 2025
a2401c0
Added pytest command to the workflow
kharshith-k Aug 14, 2025
a98c748
added grain installation command
kharshith-k Aug 14, 2025
71c5b8b
Pruned unwanted files
kharshith-k Aug 19, 2025
5522a2b
included grain in requirements.txt
kharshith-k Aug 19, 2025
7173c6d
Updated tpu-tests.yml to use python image and explicitly install spec…
kharshith-k Aug 22, 2025
a7dc789
Renamed tpu-tests to tpu-tests-jax and logging TPU device kind
kharshith-k Aug 22, 2025
e509e6d
Added a step to check gcloud installation
kharshith-k Aug 22, 2025
a7ec63b
Running pytest on generic tpu workflow
kharshith-k Aug 25, 2025
Progress bar now handles `steps_per_execution`. (#21422)
Previously, the progress bar always reported the starting batch + 1 at the end of each batch, even when several batches ran per execution. It now takes `steps_per_execution` into account, so the last batch of each execution chunk is the one reported.

Fixes #20861
hertschuh authored and kharshith-k committed Jul 23, 2025
commit f68be9775658df86f52c059beb16611f22c2418a
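To see the behavioral change from the outside, here is a minimal sketch (hypothetical model, data, and callback — not part of this PR) that logs the batch index passed to the batch-level callbacks. With `steps_per_execution=5` over 20 batches, `on_train_batch_begin` still receives the first batch of each chunk (0, 5, 10, 15), while `on_train_batch_end` now receives the last batch of each chunk (4, 9, 14, 19) instead of echoing the starting index:

```python
import numpy as np
import keras


class BatchLogger(keras.callbacks.Callback):
    def on_train_batch_begin(self, batch, logs=None):
        print("begin:", batch)  # first batch of the execution chunk

    def on_train_batch_end(self, batch, logs=None):
        print("end:", batch)  # after this PR, the last batch of the chunk


model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse", steps_per_execution=5)

# 160 samples / batch_size 8 = 20 batches = 4 execution chunks.
x = np.random.rand(160, 8).astype("float32")
y = np.random.rand(160, 1).astype("float32")
model.fit(x, y, batch_size=8, epochs=1, callbacks=[BatchLogger()], verbose=0)
```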
22 changes: 12 additions & 10 deletions keras/src/backend/jax/trainer.py
@@ -408,9 +408,9 @@ def fit(

             self._jax_state_synced = True
             with epoch_iterator.catch_stop_iteration():
-                for step, iterator in epoch_iterator:
+                for begin_step, end_step, iterator in epoch_iterator:
                     # Callbacks
-                    callbacks.on_train_batch_begin(step)
+                    callbacks.on_train_batch_begin(begin_step)

                     # Train step
                     if self._jax_state_synced:
@@ -441,7 +441,7 @@ def fit(
                         "metrics_variables": metrics_variables,
                     }
                     # Dispatch callbacks. This takes care of async dispatch.
-                    callbacks.on_train_batch_end(step, logs)
+                    callbacks.on_train_batch_end(end_step, logs)

                     if self.stop_training:
                         # Stop training if a callback has set
@@ -569,8 +569,8 @@ def evaluate(

         self._jax_state_synced = True
         with epoch_iterator.catch_stop_iteration():
-            for step, iterator in epoch_iterator:
-                callbacks.on_test_batch_begin(step)
+            for begin_step, end_step, iterator in epoch_iterator:
+                callbacks.on_test_batch_begin(begin_step)

                 if self._jax_state_synced:
                     # The state may have been synced by a callback.
@@ -600,7 +600,7 @@ def evaluate(
                 }

                 # Dispatch callbacks. This takes care of async dispatch.
-                callbacks.on_test_batch_end(step, logs)
+                callbacks.on_test_batch_end(end_step, logs)

                 if self.stop_evaluating:
                     break
@@ -633,7 +633,7 @@ def predict(

         if not all(layer.built for layer in self._flatten_layers()):
             # Build the model on one batch of data.
-            for _, iterator in epoch_iterator:
+            for _, _, iterator in epoch_iterator:
                 # Build model
                 x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(
                     next(iterator)
@@ -677,8 +677,8 @@ def append_to_outputs(batch_outputs, outputs):
         outputs = None
         non_trainable_variables = None
         with epoch_iterator.catch_stop_iteration():
-            for step, iterator in epoch_iterator:
-                callbacks.on_predict_batch_begin(step)
+            for begin_step, end_step, iterator in epoch_iterator:
+                callbacks.on_predict_batch_begin(begin_step)
                 if self._jax_state_synced:
                     # The state may have been synced by a callback.
                     state = self._get_jax_state(
@@ -701,7 +701,9 @@ def append_to_outputs(batch_outputs, outputs):
                 outputs = append_to_outputs(batch_outputs, outputs)

                 # Dispatch callbacks. This takes care of async dispatch.
-                callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
+                callbacks.on_predict_batch_end(
+                    end_step, {"outputs": batch_outputs}
+                )

                 if self.stop_predicting:
                     break
14 changes: 7 additions & 7 deletions keras/src/backend/numpy/trainer.py
@@ -211,11 +211,11 @@ def append_to_outputs(batch_outputs, outputs):
         self.stop_predicting = False
         callbacks.on_predict_begin()
         outputs = None
-        for step, data in epoch_iterator:
-            callbacks.on_predict_batch_begin(step)
+        for begin_step, end_step, data in epoch_iterator:
+            callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(data)
             outputs = append_to_outputs(batch_outputs, outputs)
-            callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
+            callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
@@ -255,7 +255,7 @@ def evaluate(

         if not all(layer.built for layer in self._flatten_layers()):
             # Build the model on one batch of data.
-            for _, data in epoch_iterator:
+            for _, _, data in epoch_iterator:
                 data_batch = data[0]
                 self._symbolic_build(data_batch)
                 break
@@ -276,10 +276,10 @@ def evaluate(
         callbacks.on_test_begin()
         logs = {}
         self.reset_metrics()
-        for step, data in epoch_iterator:
-            callbacks.on_test_batch_begin(step)
+        for begin_step, end_step, data in epoch_iterator:
+            callbacks.on_test_batch_begin(begin_step)
             logs = self.test_function(data)
-            callbacks.on_test_batch_end(step, logs)
+            callbacks.on_test_batch_end(end_step, logs)
             if self.stop_evaluating:
                 break
         logs = self._get_metrics_result_or_logs(logs)
6 changes: 3 additions & 3 deletions keras/src/backend/openvino/trainer.py
@@ -213,11 +213,11 @@ def append_to_outputs(batch_outputs, outputs):
         self.stop_predicting = False
         callbacks.on_predict_begin()
         outputs = None
-        for step, data in epoch_iterator.enumerate_epoch():
-            callbacks.on_predict_batch_begin(step)
+        for begin_step, end_step, data in epoch_iterator.enumerate_epoch():
+            callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(data)
             outputs = append_to_outputs(batch_outputs, outputs)
-            callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
+            callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
2 changes: 1 addition & 1 deletion keras/src/backend/tensorflow/distribute_test.py
@@ -104,7 +104,7 @@ def test_epoch_iterator(self):
             distribute_strategy=strategy,
         )
         steps_seen = []
-        for step, data_iterator in epoch_iterator:
+        for step, _, data_iterator in epoch_iterator:
             steps_seen.append(step)
             batch = next(data_iterator)
             self.assertEqual(len(batch), 3)
22 changes: 12 additions & 10 deletions keras/src/backend/tensorflow/trainer.py
@@ -372,10 +372,10 @@ def fit(
             self.reset_metrics()
             callbacks.on_epoch_begin(epoch)
             with epoch_iterator.catch_stop_iteration():
-                for step, iterator in epoch_iterator:
-                    callbacks.on_train_batch_begin(step)
+                for begin_step, end_step, iterator in epoch_iterator:
+                    callbacks.on_train_batch_begin(begin_step)
                     logs = self.train_function(iterator)
-                    callbacks.on_train_batch_end(step, logs)
+                    callbacks.on_train_batch_end(end_step, logs)
                     if self.stop_training:
                         break

@@ -484,10 +484,10 @@ def evaluate(
         logs = {}
         self.reset_metrics()
         with epoch_iterator.catch_stop_iteration():
-            for step, iterator in epoch_iterator:
-                callbacks.on_test_batch_begin(step)
+            for begin_step, end_step, iterator in epoch_iterator:
+                callbacks.on_test_batch_begin(begin_step)
                 logs = self.test_function(iterator)
-                callbacks.on_test_batch_end(step, logs)
+                callbacks.on_test_batch_end(end_step, logs)
                 if self.stop_evaluating:
                     break
         logs = self._get_metrics_result_or_logs(logs)
@@ -560,12 +560,14 @@ def get_data(iterator):
         callbacks.on_predict_begin()
         outputs = None
         with epoch_iterator.catch_stop_iteration():
-            for step, iterator in epoch_iterator:
-                callbacks.on_predict_batch_begin(step)
+            for begin_step, end_step, iterator in epoch_iterator:
+                callbacks.on_predict_batch_begin(begin_step)
                 data = get_data(iterator)
                 batch_outputs = self.predict_function(data)
                 outputs = append_to_outputs(batch_outputs, outputs)
-                callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
+                callbacks.on_predict_batch_end(
+                    end_step, {"outputs": batch_outputs}
+                )
                 if self.stop_predicting:
                     break
         callbacks.on_predict_end()
@@ -696,7 +698,7 @@ def _maybe_symbolic_build(self, iterator=None, data_batch=None):
         # Unlike jax/torch iterator, tf iterator returns an iterator instead
         # of data batch in `iterator`.
         if iterator is not None:
-            for _, it in iterator:
+            for _, _, it in iterator:
                 maybe_distributed_data_batch = next(it)
                 has_distributed_values = tree.map_structure(
                     lambda x: isinstance(x, tf.distribute.DistributedValues),
18 changes: 9 additions & 9 deletions keras/src/backend/torch/trainer.py
@@ -256,14 +256,14 @@ def fit(
             self.train()

             logs = {}
-            for step, data in epoch_iterator:
+            for begin_step, end_step, data in epoch_iterator:
                 # Callbacks
-                callbacks.on_train_batch_begin(step)
+                callbacks.on_train_batch_begin(begin_step)

                 logs = self.train_function(data)

                 # Callbacks
-                callbacks.on_train_batch_end(step, logs)
+                callbacks.on_train_batch_end(end_step, logs)
                 if self.stop_training:
                     break

@@ -374,10 +374,10 @@ def evaluate(
         callbacks.on_test_begin()
         logs = {}
         self.reset_metrics()
-        for step, data in epoch_iterator:
-            callbacks.on_test_batch_begin(step)
+        for begin_step, end_step, data in epoch_iterator:
+            callbacks.on_test_batch_begin(begin_step)
             logs = self.test_function(data)
-            callbacks.on_test_batch_end(step, logs)
+            callbacks.on_test_batch_end(end_step, logs)
             if self.stop_evaluating:
                 break
         logs = self._get_metrics_result_or_logs(logs)
@@ -433,11 +433,11 @@ def append_to_outputs(batch_outputs, outputs):
         self.stop_predicting = False
         callbacks.on_predict_begin()
         outputs = None
-        for step, data in epoch_iterator:
-            callbacks.on_predict_batch_begin(step)
+        for begin_step, end_step, data in epoch_iterator:
+            callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(data)
             outputs = append_to_outputs(batch_outputs, outputs)
-            callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
+            callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
18 changes: 11 additions & 7 deletions keras/src/trainers/epoch_iterator.py
@@ -116,7 +116,11 @@ def _enumerate_iterator(self):
                     self._interrupted_warning()
                     break
                 self._steps_seen += self.steps_per_execution
-                yield step, self._current_iterator
+                yield (
+                    step,
+                    step + self.steps_per_execution - 1,
+                    self._current_iterator,
+                )
             if self._num_batches and self._steps_seen >= self._num_batches:
                 self._current_iterator = iter(self._get_iterator())
                 self._steps_seen = 0
@@ -126,7 +130,7 @@ def _enumerate_iterator(self):
             while True:
                 step += self.steps_per_execution
                 self._steps_seen = step + self.steps_per_execution
-                yield step, iterator
+                yield step, step + self.steps_per_execution - 1, iterator
         self.data_adapter.on_epoch_end()

     def __iter__(self):
@@ -135,19 +139,19 @@ def __iter__(self):

     def __next__(self):
         buffer = []
-        step, iterator = next(self._epoch_iterator)
+        begin_step, end_step, iterator = next(self._epoch_iterator)
         with self.catch_stop_iteration():
             for _ in range(self.steps_per_execution):
                 data = next(iterator)
                 buffer.append(data)
-            return step, buffer
+            return begin_step, end_step, buffer
         if buffer:
-            return step, buffer
+            return begin_step, end_step, buffer
         raise StopIteration

     def enumerate_epoch(self):
-        for step, data in self:
-            yield step, data
+        for begin_step, end_step, data in self:
+            yield begin_step, end_step, data

     @contextlib.contextmanager
     def catch_stop_iteration(self):
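The core of this change is the pair yielded above: `begin_step` is the old `step`, and `end_step` is simply `begin_step + steps_per_execution - 1`, with no clamping. A standalone sketch of that arithmetic (not the Keras class itself, just the yield logic in isolation):

```python
def enumerate_steps(num_batches, steps_per_execution):
    # Mirrors the yields in `_enumerate_iterator`: end_step is always
    # begin_step + steps_per_execution - 1; no clamping is applied.
    for begin_step in range(0, num_batches, steps_per_execution):
        yield begin_step, begin_step + steps_per_execution - 1


print(list(enumerate_steps(7, 2)))
# [(0, 1), (2, 3), (4, 5), (6, 7)]
print(list(enumerate_steps(7, 1)))
# [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6)]
# begin_step == end_step when steps_per_execution is 1, which is what
# the updated test below asserts.
```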
11 changes: 6 additions & 5 deletions keras/src/trainers/epoch_iterator_test.py
@@ -31,9 +31,10 @@ def test_basic_flow(self, call_type):
             generator = iterator
         else:
             generator = iterator.enumerate_epoch()
-        for step, batch in generator:
+        for begin_step, end_step, batch in generator:
             batch = batch[0]
-            steps_seen.append(step)
+            steps_seen.append(begin_step)
+            self.assertEqual(begin_step, end_step)
             self.assertEqual(len(batch), 3)
             self.assertIsInstance(batch[0], np.ndarray)
         self.assertEqual(steps_seen, [0, 1, 2, 3, 4, 5, 6])
@@ -52,7 +53,7 @@ def test_insufficient_data(self):
         )
         steps_seen = []
         with pytest.warns(match="Your input ran out of data"):
-            for step, _ in iterator:
+            for step, _, _ in iterator:
                 steps_seen.append(step)
         self.assertLen(steps_seen, steps_per_epoch - 2)

@@ -96,7 +97,7 @@ def __getitem__(self, idx):
             torch_dataset, batch_size=8, shuffle=True
         )
         iterator = epoch_iterator.EpochIterator(torch_dataloader)
-        for _, batch in iterator:
+        for _, _, batch in iterator:
             batch = batch[0]
             self.assertEqual(batch[0].shape, (8, 2))
             self.assertEqual(batch[1].shape, (8, 1))
@@ -226,7 +227,7 @@ def on_epoch_end(self):

         num_epochs = 5
         for epoch in range(num_epochs):
-            for step, batch in epoch_iter:
+            for _, _, _ in epoch_iter:
                 pass

         self.assertAllEqual(ds.tracker, [1, 2] * num_epochs)
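For reference, this is how downstream code consumes the new three-tuple protocol directly — a hedged sketch along the lines of the updated tests, assuming the internal `EpochIterator` constructor accepts the `steps_per_execution` argument as shown in the diff above:

```python
import numpy as np
from keras.src.trainers import epoch_iterator

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
iterator = epoch_iterator.EpochIterator(
    x=x, y=y, batch_size=8, steps_per_execution=2
)
for begin_step, end_step, batch in iterator:
    # `batch` is a buffer of up to `steps_per_execution` data batches.
    print(begin_step, end_step, len(batch))
# Expected: "0 1 2" then "2 3 2" -- 4 batches consumed in chunks of 2.
```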
2 changes: 1 addition & 1 deletion keras/src/trainers/trainer.py
@@ -1072,7 +1072,7 @@ def to_symbolic_input(v):
         )

         if data_batch is None:
-            for _, data_or_iterator in iterator:
+            for _, _, data_or_iterator in iterator:
                 if isinstance(data_or_iterator, (list, tuple)):
                     data_batch = data_or_iterator[0]
                 else: