mirror of
https://github.com/TheAlgorithms/Java.git
synced 2025-07-07 01:35:16 +08:00
Add K-D Tree (#3210)
This commit is contained in:

committed by
GitHub

parent
854b900257
commit
05660dae92
401
src/main/java/com/thealgorithms/datastructures/trees/KDTree.java
Normal file
401
src/main/java/com/thealgorithms/datastructures/trees/KDTree.java
Normal file
@ -0,0 +1,401 @@
|
|||||||
|
package com.thealgorithms.datastructures.trees;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* K-D Tree Implementation
|
||||||
|
* Wikipedia: https://en.wikipedia.org/wiki/K-d_tree
|
||||||
|
*
|
||||||
|
* Author: Amir Hosseini (https://github.com/itsamirhn)
|
||||||
|
*
|
||||||
|
* */
|
||||||
|
|
||||||
|
public class KDTree {
|
||||||
|
|
||||||
|
private Node root;
|
||||||
|
|
||||||
|
private final int k; // Dimensions of the points
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor for empty KDTree
|
||||||
|
*
|
||||||
|
* @param k Number of dimensions
|
||||||
|
*/
|
||||||
|
KDTree(int k) {
|
||||||
|
this.k = k;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds the KDTree from the specified points
|
||||||
|
*
|
||||||
|
* @param points Array of initial points
|
||||||
|
*/
|
||||||
|
KDTree(Point[] points) {
|
||||||
|
if (points.length == 0) throw new IllegalArgumentException("Points array cannot be empty");
|
||||||
|
this.k = points[0].getDimension();
|
||||||
|
for (Point point : points) if (point.getDimension() != k) throw new IllegalArgumentException("Points must have the same dimension");
|
||||||
|
this.root = build(points, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds the KDTree from the specified coordinates of the points
|
||||||
|
*
|
||||||
|
* @param pointsCoordinates Array of initial points coordinates
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
KDTree(int[][] pointsCoordinates) {
|
||||||
|
if (pointsCoordinates.length == 0) throw new IllegalArgumentException("Points array cannot be empty");
|
||||||
|
this.k = pointsCoordinates[0].length;
|
||||||
|
Point[] points = Arrays.stream(pointsCoordinates)
|
||||||
|
.map(Point::new)
|
||||||
|
.toArray(Point[]::new);
|
||||||
|
for (Point point : points) if (point.getDimension() != k) throw new IllegalArgumentException("Points must have the same dimension");
|
||||||
|
this.root = build(points, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static class Point {
|
||||||
|
int[] coordinates;
|
||||||
|
|
||||||
|
public int getCoordinate(int i) {
|
||||||
|
return coordinates[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getDimension() {
|
||||||
|
return coordinates.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Point(int[] coordinates) {
|
||||||
|
this.coordinates = coordinates;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (obj instanceof Point other) {
|
||||||
|
if (other.getDimension() != this.getDimension()) return false;
|
||||||
|
return Arrays.equals(other.coordinates, this.coordinates);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return Arrays.toString(coordinates);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find the comparable distance between two points (without SQRT)
|
||||||
|
*
|
||||||
|
* @param p1 First point
|
||||||
|
* @param p2 Second point
|
||||||
|
*
|
||||||
|
* @return The comparable distance between the two points
|
||||||
|
*/
|
||||||
|
static public int comparableDistance(Point p1, Point p2) {
|
||||||
|
int distance = 0;
|
||||||
|
for (int i = 0; i < p1.getDimension(); i++) {
|
||||||
|
int t = p1.getCoordinate(i) - p2.getCoordinate(i);
|
||||||
|
distance += t * t;
|
||||||
|
}
|
||||||
|
return distance;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find the comparable distance between two points with ignoring specified axis
|
||||||
|
*
|
||||||
|
* @param p1 First point
|
||||||
|
* @param p2 Second point
|
||||||
|
* @param axis The axis to ignore
|
||||||
|
*
|
||||||
|
* @return The distance between the two points
|
||||||
|
*/
|
||||||
|
static public int comparableDistanceExceptAxis(Point p1, Point p2, int axis) {
|
||||||
|
int distance = 0;
|
||||||
|
for (int i = 0; i < p1.getDimension(); i++) {
|
||||||
|
if (i == axis) continue;
|
||||||
|
int t = p1.getCoordinate(i) - p2.getCoordinate(i);
|
||||||
|
distance += t * t;
|
||||||
|
}
|
||||||
|
return distance;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static class Node {
|
||||||
|
private Point point;
|
||||||
|
private int axis; // 0 for x, 1 for y, 2 for z, etc.
|
||||||
|
|
||||||
|
private Node left = null; // Left child
|
||||||
|
private Node right = null; // Right child
|
||||||
|
|
||||||
|
Node(Point point, int axis) {
|
||||||
|
this.point = point;
|
||||||
|
this.axis = axis;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Point getPoint() {
|
||||||
|
return point;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node getLeft() {
|
||||||
|
return left;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node getRight() {
|
||||||
|
return right;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getAxis() {
|
||||||
|
return axis;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the nearest child according to the specified point
|
||||||
|
*
|
||||||
|
* @param point The point to find the nearest child to
|
||||||
|
*
|
||||||
|
* @return The nearest child Node
|
||||||
|
*/
|
||||||
|
public Node getNearChild(Point point) {
|
||||||
|
if (point.getCoordinate(axis) < this.point.getCoordinate(axis)) return left;
|
||||||
|
else return right;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the farthest child according to the specified point
|
||||||
|
*
|
||||||
|
* @param point The point to find the farthest child to
|
||||||
|
*
|
||||||
|
* @return The farthest child Node
|
||||||
|
*/
|
||||||
|
public Node getFarChild(Point point) {
|
||||||
|
if (point.getCoordinate(axis) < this.point.getCoordinate(axis)) return right;
|
||||||
|
else return left;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the node axis coordinate of point
|
||||||
|
*
|
||||||
|
* @return The axis coordinate of the point
|
||||||
|
*/
|
||||||
|
public int getAxisCoordinate() {
|
||||||
|
return point.getCoordinate(axis);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node getRoot() {
|
||||||
|
return root;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds the KDTree from the specified points
|
||||||
|
*
|
||||||
|
* @param points Array of initial points
|
||||||
|
* @param depth The current depth of the tree
|
||||||
|
*
|
||||||
|
* @return The root of the KDTree
|
||||||
|
*/
|
||||||
|
private Node build(Point[] points, int depth) {
|
||||||
|
if (points.length == 0) return null;
|
||||||
|
int axis = depth % k;
|
||||||
|
if (points.length == 1) return new Node(points[0], axis);
|
||||||
|
Arrays.sort(points, Comparator.comparingInt(o -> o.getCoordinate(axis)));
|
||||||
|
int median = points.length >> 1;
|
||||||
|
Node node = new Node(points[median], axis);
|
||||||
|
node.left = build(Arrays.copyOfRange(points, 0, median), depth + 1);
|
||||||
|
node.right = build(Arrays.copyOfRange(points, median + 1, points.length), depth + 1);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Insert a point into the KDTree
|
||||||
|
*
|
||||||
|
* @param point The point to insert
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public void insert(Point point) {
|
||||||
|
if (point.getDimension() != k) throw new IllegalArgumentException("Point has wrong dimension");
|
||||||
|
root = insert(root, point, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Insert a point into a subtree
|
||||||
|
*
|
||||||
|
* @param root The root of the subtree
|
||||||
|
* @param point The point to insert
|
||||||
|
* @param depth The current depth of the tree
|
||||||
|
*
|
||||||
|
* @return The root of the KDTree
|
||||||
|
*/
|
||||||
|
private Node insert(Node root, Point point, int depth) {
|
||||||
|
int axis = depth % k;
|
||||||
|
if (root == null) return new Node(point, axis);
|
||||||
|
if (point.getCoordinate(axis) < root.getAxisCoordinate()) root.left = insert(root.left, point, depth + 1);
|
||||||
|
else root.right = insert(root.right, point, depth + 1);
|
||||||
|
|
||||||
|
return root;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Search for Node corresponding to the specified point in the KDTree
|
||||||
|
*
|
||||||
|
* @param point The point to search for
|
||||||
|
*
|
||||||
|
* @return The Node corresponding to the specified point
|
||||||
|
*/
|
||||||
|
public Optional<Node> search(Point point) {
|
||||||
|
if (point.getDimension() != k) throw new IllegalArgumentException("Point has wrong dimension");
|
||||||
|
return search(root, point);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Search for Node corresponding to the specified point in a subtree
|
||||||
|
*
|
||||||
|
* @param root The root of the subtree to search in
|
||||||
|
* @param point The point to search for
|
||||||
|
*
|
||||||
|
* @return The Node corresponding to the specified point
|
||||||
|
*/
|
||||||
|
public Optional<Node> search(Node root, Point point) {
|
||||||
|
if (root == null) return Optional.empty();
|
||||||
|
if (root.point.equals(point)) return Optional.of(root);
|
||||||
|
return search(root.getNearChild(point), point);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find a point with minimum value in specified axis in the KDTree
|
||||||
|
*
|
||||||
|
* @param axis The axis to find the minimum value in
|
||||||
|
*
|
||||||
|
* @return The point with minimum value in the specified axis
|
||||||
|
*/
|
||||||
|
public Point findMin(int axis) {
|
||||||
|
return findMin(root, axis).point;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find a point with minimum value in specified axis in a subtree
|
||||||
|
*
|
||||||
|
* @param root The root of the subtree to search in
|
||||||
|
* @param axis The axis to find the minimum value in
|
||||||
|
*
|
||||||
|
* @return The Node with minimum value in the specified axis of the point
|
||||||
|
*/
|
||||||
|
public Node findMin(Node root, int axis) {
|
||||||
|
if (root == null) return null;
|
||||||
|
if (root.getAxis() == axis) {
|
||||||
|
if (root.left == null) return root;
|
||||||
|
return findMin(root.left, axis);
|
||||||
|
} else {
|
||||||
|
Node left = findMin(root.left, axis);
|
||||||
|
Node right = findMin(root.right, axis);
|
||||||
|
Node[] candidates = {left, root, right};
|
||||||
|
return Arrays.stream(candidates)
|
||||||
|
.filter(Objects::nonNull)
|
||||||
|
.min(Comparator.comparingInt(a -> a.point.getCoordinate(axis))).orElse(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find a point with maximum value in specified axis in the KDTree
|
||||||
|
*
|
||||||
|
* @param axis The axis to find the maximum value in
|
||||||
|
*
|
||||||
|
* @return The point with maximum value in the specified axis
|
||||||
|
*/
|
||||||
|
public Point findMax(int axis) {
|
||||||
|
return findMax(root, axis).point;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find a point with maximum value in specified axis in a subtree
|
||||||
|
*
|
||||||
|
* @param root The root of the subtree to search in
|
||||||
|
* @param axis The axis to find the maximum value in
|
||||||
|
*
|
||||||
|
* @return The Node with maximum value in the specified axis of the point
|
||||||
|
*/
|
||||||
|
public Node findMax(Node root, int axis) {
|
||||||
|
if (root == null) return null;
|
||||||
|
if (root.getAxis() == axis) {
|
||||||
|
if (root.right == null) return root;
|
||||||
|
return findMax(root.right, axis);
|
||||||
|
} else {
|
||||||
|
Node left = findMax(root.left, axis);
|
||||||
|
Node right = findMax(root.right, axis);
|
||||||
|
Node[] candidates = {left, root, right};
|
||||||
|
return Arrays.stream(candidates)
|
||||||
|
.filter(Objects::nonNull)
|
||||||
|
.max(Comparator.comparingInt(a -> a.point.getCoordinate(axis))).orElse(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Delete the node with the given point.
|
||||||
|
*
|
||||||
|
* @param point the point to delete
|
||||||
|
* */
|
||||||
|
public void delete(Point point) {
|
||||||
|
Node node = search(point).orElseThrow(() -> new IllegalArgumentException("Point not found"));
|
||||||
|
root = delete(root, node);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Delete the specified node from a subtree.
|
||||||
|
*
|
||||||
|
* @param root The root of the subtree to delete from
|
||||||
|
* @param node The node to delete
|
||||||
|
*
|
||||||
|
* @return The new root of the subtree
|
||||||
|
*/
|
||||||
|
private Node delete(Node root, Node node) {
|
||||||
|
if (root == null) return null;
|
||||||
|
if (root.equals(node)) {
|
||||||
|
if (root.right != null) {
|
||||||
|
Node min = findMin(root.right, root.getAxis());
|
||||||
|
root.point = min.point;
|
||||||
|
root.right = delete(root.right, min);
|
||||||
|
} else if (root.left != null) {
|
||||||
|
Node min = findMin(root.left, root.getAxis());
|
||||||
|
root.point = min.point;
|
||||||
|
root.left = delete(root.left, min);
|
||||||
|
} else return null;
|
||||||
|
}
|
||||||
|
if (root.getAxisCoordinate() < node.point.getCoordinate(root.getAxis())) root.left = delete(root.left, node);
|
||||||
|
else root.right = delete(root.right, node);
|
||||||
|
return root;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Finds the nearest point in the tree to the given point.
|
||||||
|
*
|
||||||
|
* @param point The point to find the nearest neighbor to.
|
||||||
|
* */
|
||||||
|
public Point findNearest(Point point) {
|
||||||
|
return findNearest(root, point, root).point;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Finds the nearest point in a subtree to the given point.
|
||||||
|
*
|
||||||
|
* @param root The root of the subtree to search in.
|
||||||
|
* @param point The point to find the nearest neighbor to.
|
||||||
|
* @param nearest The nearest neighbor found so far.
|
||||||
|
* */
|
||||||
|
private Node findNearest(Node root, Point point, Node nearest) {
|
||||||
|
if (root == null) return nearest;
|
||||||
|
if (root.point.equals(point)) return root;
|
||||||
|
int distance = Point.comparableDistance(root.point, point);
|
||||||
|
int distanceExceptAxis = Point.comparableDistanceExceptAxis(root.point, point, root.getAxis());
|
||||||
|
if (distance < Point.comparableDistance(nearest.point, point)) nearest = root;
|
||||||
|
nearest = findNearest(root.getNearChild(point), point, nearest);
|
||||||
|
if (distanceExceptAxis < Point.comparableDistance(nearest.point, point))
|
||||||
|
nearest = findNearest(root.getFarChild(point), point, nearest);
|
||||||
|
return nearest;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,64 @@
|
|||||||
|
package com.thealgorithms.datastructures.trees;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class KDTreeTest {
|
||||||
|
|
||||||
|
KDTree.Point pointOf(int x, int y) {
|
||||||
|
return new KDTree.Point(new int[]{x, y});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void findMin() {
|
||||||
|
int[][] coordinates = {
|
||||||
|
{30, 40},
|
||||||
|
{5, 25},
|
||||||
|
{70, 70},
|
||||||
|
{10, 12},
|
||||||
|
{50, 30},
|
||||||
|
{35, 45}
|
||||||
|
};
|
||||||
|
KDTree kdTree = new KDTree(coordinates);
|
||||||
|
|
||||||
|
assertEquals(5, kdTree.findMin(0).getCoordinate(0));
|
||||||
|
assertEquals(12, kdTree.findMin(1).getCoordinate(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void delete() {
|
||||||
|
int[][] coordinates = {
|
||||||
|
{30, 40},
|
||||||
|
{5, 25},
|
||||||
|
{70, 70},
|
||||||
|
{10, 12},
|
||||||
|
{50, 30},
|
||||||
|
{35, 45}
|
||||||
|
};
|
||||||
|
KDTree kdTree = new KDTree(coordinates);
|
||||||
|
|
||||||
|
kdTree.delete(pointOf(30, 40));
|
||||||
|
assertEquals(35, kdTree.getRoot().getPoint().getCoordinate(0));
|
||||||
|
assertEquals(45, kdTree.getRoot().getPoint().getCoordinate(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void findNearest() {
|
||||||
|
int[][] coordinates = {
|
||||||
|
{2, 3},
|
||||||
|
{5, 4},
|
||||||
|
{9, 6},
|
||||||
|
{4, 7},
|
||||||
|
{8, 1},
|
||||||
|
{7, 2}
|
||||||
|
};
|
||||||
|
KDTree kdTree = new KDTree(coordinates);
|
||||||
|
|
||||||
|
assertEquals(pointOf(7, 2), kdTree.findNearest(pointOf(7, 2)));
|
||||||
|
assertEquals(pointOf(8, 1), kdTree.findNearest(pointOf(8, 1)));
|
||||||
|
assertEquals(pointOf(2, 3), kdTree.findNearest(pointOf(1, 1)));
|
||||||
|
assertEquals(pointOf(5, 4), kdTree.findNearest(pointOf(5, 5)));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Reference in New Issue
Block a user