Fix DeprecationWarning in local_weighted_learning.py (#9165)

Fix DeprecationWarning that occurs during build due to converting an
np.ndarray to a scalar implicitly
This commit is contained in:
Tianyi Zheng
2023-09-30 23:31:35 -04:00
committed by GitHub
parent aaf7195465
commit 5f8d1cb5c9

View File

@ -122,7 +122,7 @@ def local_weight_regression(
"""
y_pred = np.zeros(len(x_train)) # Initialize array of predictions
for i, item in enumerate(x_train):
y_pred[i] = item @ local_weight(item, x_train, y_train, tau)
y_pred[i] = np.dot(item, local_weight(item, x_train, y_train, tau))
return y_pred