Enable ruff NPY002 rule (#11336)

This commit is contained in:
Maxim Smolskiy
2024-04-01 22:39:31 +03:00
committed by GitHub
parent 39daaf8248
commit f8a948914b
9 changed files with 32 additions and 25 deletions

View File

@ -153,7 +153,7 @@ class _DataSet:
"""
seed1, seed2 = random_seed.get_seed(seed)
# If op level seed is not set, use whatever graph level seed is returned
np.random.seed(seed1 if seed is None else seed2)
self._rng = np.random.default_rng(seed1 if seed is None else seed2)
dtype = dtypes.as_dtype(dtype).base_dtype
if dtype not in (dtypes.uint8, dtypes.float32):
raise TypeError("Invalid image dtype %r, expected uint8 or float32" % dtype)
@ -211,7 +211,7 @@ class _DataSet:
# Shuffle for the first epoch
if self._epochs_completed == 0 and start == 0 and shuffle:
perm0 = np.arange(self._num_examples)
np.random.shuffle(perm0)
self._rng.shuffle(perm0)
self._images = self.images[perm0]
self._labels = self.labels[perm0]
# Go to the next epoch
@ -225,7 +225,7 @@ class _DataSet:
# Shuffle the data
if shuffle:
perm = np.arange(self._num_examples)
np.random.shuffle(perm)
self._rng.shuffle(perm)
self._images = self.images[perm]
self._labels = self.labels[perm]
# Start next epoch