K-近邻


题目要求

  1. 根据数据集构造 kd-tree
  2. 基于 kd-tree,对于给定的 x,输出其最近邻元素及其欧式距离
  3. 基于 kd-tree,对于给定的 x,和正整数 n,输出其 n 个最近邻元素列表及其距离值

基本原理

$K-NearestNeighbor$,每个样本点都可以用它最近的K个近邻值来代表。

又名,基于实例的学习

代码实现

# heapq 优先队列算法,是一个原生的 python list, 0 号元素总为最小的元素

from collections import namedtuple
from math import sqrt

#定义一个命名元组
result = namedtuple("Result", "nearest_point nearest_dist nodes_visited")
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]

class Node:
    def __init__(self, point, d, left, right, cnt):
        self.point = point
        self.d = d          #维度
        self.left = left
        self.right = right
        self.cnt = cnt     #节点以下的数目(包括该节点)

class KdTree:
    def __init__(self, data):
        if data:
            k = len(data[0])
        def create(d, data_set):                       
            if not data_set: return None
            data_set.sort(key = lambda x: x[d])
            pos = len(data_set)//2                      # 需要分开的位置,最后会向右取
            newd = (d+1)%2
            return Node(data_set[pos], d, create(newd, data_set[:pos]), create(newd, data_set[pos+1:]), len(data_set))

        self.root = create(0, data)

    def preorder(self):
        def fun(node):
            print(node.point)
            print(node.cnt)
            if node.left: fun(node.left)
            if node.right: fun(node.right)
        fun(self.root)

    def test(self, point):
        k = len(point)  # 数据维度

        def travel(node, target, max_dist):
            if node is None: return result([0] * k, float("inf"), 0)     #出口0

            nodes_visited = 1

            d = node.d           #比较的维度
            point = node.point 

            nearer_node = node.left if target[d] <= point[d] else node.right       #下一步走的两个点
            further_node = node.right if target[d] <= point[d] else node.left

            temp1 = travel(nearer_node, target, max_dist)  # 进行遍历找到包含目标点的区域

            nearest = temp1.nearest_point  # 以此叶结点作为“当前最近点”
            dist = temp1.nearest_dist      # 更新最近距离

            nodes_visited += temp1.nodes_visited

            if dist < max_dist:
                max_dist = dist  # 最近点将在以目标点为球心,max_dist为半径的超球体内

            if max_dist < abs(point[d] - target[d]):       #出口1,另一超矩形无用
                return result(nearest, dist, nodes_visited)  

            temp_dist = sqrt(sum((x - y)**2 for x,y in zip(point, target)))

            if temp_dist < dist:  # 如果“更近”
                nearest = point   # 更新最近点
                dist = temp_dist  # 更新最近距离
                max_dist = dist   # 更新超球体半径

            # 检查另一个子结点对应的区域是否有更近的点
            temp2 = travel(further_node, target, max_dist)

            nodes_visited += temp2.nodes_visited
            if temp2.nearest_dist < dist:  # 如果另一个子结点内存在更近距离
                nearest = temp2.nearest_point  # 更新最近点
                dist = temp2.nearest_dist  # 更新最近距离

            return result(nearest, dist, nodes_visited)
        return travel(self.root, point, float('inf'))

if __name__ == "__main__":
    kd = KdTree(data)
    #kd.preorder()
    print(kd.test([2,2]))
    print(kd.test([11,12]))


#第三问写不出来,想用heapq的

文章作者: ╯晓~
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 ╯晓~ !
评论
  目录