mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-07-10 06:10:33 +08:00
Enable ruff NPY002 rule (#11336)
This commit is contained in:
@ -153,7 +153,7 @@ class _DataSet:
|
||||
"""
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
# If op level seed is not set, use whatever graph level seed is returned
|
||||
np.random.seed(seed1 if seed is None else seed2)
|
||||
self._rng = np.random.default_rng(seed1 if seed is None else seed2)
|
||||
dtype = dtypes.as_dtype(dtype).base_dtype
|
||||
if dtype not in (dtypes.uint8, dtypes.float32):
|
||||
raise TypeError("Invalid image dtype %r, expected uint8 or float32" % dtype)
|
||||
@ -211,7 +211,7 @@ class _DataSet:
|
||||
# Shuffle for the first epoch
|
||||
if self._epochs_completed == 0 and start == 0 and shuffle:
|
||||
perm0 = np.arange(self._num_examples)
|
||||
np.random.shuffle(perm0)
|
||||
self._rng.shuffle(perm0)
|
||||
self._images = self.images[perm0]
|
||||
self._labels = self.labels[perm0]
|
||||
# Go to the next epoch
|
||||
@ -225,7 +225,7 @@ class _DataSet:
|
||||
# Shuffle the data
|
||||
if shuffle:
|
||||
perm = np.arange(self._num_examples)
|
||||
np.random.shuffle(perm)
|
||||
self._rng.shuffle(perm)
|
||||
self._images = self.images[perm]
|
||||
self._labels = self.labels[perm]
|
||||
# Start next epoch
|
||||
|
Reference in New Issue
Block a user