lru 算法的原理简而言之就是一个 hash ,一个 double linked list

  • Linked List 提供 O(1) 的复杂度对元素进行插入和删除
  • hash 提供 O(1) 的复杂度进行查找

本文主要是通过阅读一个 rust 实现的 lru 学习相关语法。

  • 如何在结构体里面使用指针?
  • rust 是否有 raw pointer 直接指向内存地址,如果能用该怎么用?

Linked List 节点结构体

上面提到,真正的 key/value 是存在双链表的 Node 里,所以需要先定义这个 Node 长什么样,lru-rs 中 LruEntry 表示的就是 node:

  • K V 代表的是泛型的类型,
struct LruEntry<K, V> {
    key: mem::MaybeUninit<K>,
    val: mem::MaybeUninit<V>,
    prev: *mut LruEntry<K, V>,
    next: *mut LruEntry<K, V>,
}

下面是如何初始化一个 Node,

impl<K, V> LruEntry<K, V> {
    fn new(key: K, val: V) -> Self {
        LruEntry {
            key: mem::MaybeUninit::new(key),
            val: mem::MaybeUninit::new(val),
            prev: ptr::null_mut(),
            next: ptr::null_mut(),
        }
    }

    fn new_sigil() -> Self {
        LruEntry {
            key: mem::MaybeUninit::uninit(),
            val: mem::MaybeUninit::uninit(),
            prev: ptr::null_mut(),
            next: ptr::null_mut(),
        }
    }
}
  • key value 用 mem::MaybeUninit::new(key)进行初始化
  • prev next 指针用 ptr::null_mut() 初始化

LRU cache 结构体

链表的 node 定义好以后,双链表结构也自然而然就有了。接下来还缺一个 map 结构体,这个可以用 rust 原生的 hash 函数库,然后就可以定义出 LRU 结构体

pub struct LruCache<K, V, S = DefaultHasher> {
    map: HashMap<KeyRef<K>, Box<LruEntry<K, V>>, S>,
    cap: usize,

    // head and tail are sigil nodes to faciliate inserting entries
    head: *mut LruEntry<K, V>,
    tail: *mut LruEntry<K, V>,
}
  • head tail 指针和 Node 里面的指针用同样的定义方法。
  • map 用的是 rust 自带的 HashMap 结构,它需要输入一个 key 类型 K,一个 value 类型 V,和 hash 算法 S,不指明 S 的话会调用默认函数。
  • K 类型 是 KeyRef,这是自定义的类型,下面将会介绍
  • V 类型是用 Box 加上 LruEntry。Box是堆上分配的指针类型,称为“装箱”(boxed),其指针本身在栈,指向的数据在堆,见 参考资料1

新建一个 LRU Cache 的实现

pub fn new(cap: usize) -> LruCache<K, V> {
        LruCache::construct(cap, HashMap::with_capacity(cap))
}

fn construct(cap: usize, map: HashMap<KeyRef<K>, Box<LruEntry<K, V>>, S>) -> LruCache<K, V, S> {
        let cache = LruCache {
            map,
            cap,
            head: Box::into_raw(Box::new(LruEntry::new_sigil())),
            tail: Box::into_raw(Box::new(LruEntry::new_sigil())),
        };

        unsafe {
            (*cache.head).next = cache.tail;
            (*cache.tail).prev = cache.head;
        }

        cache
    }

*const T和*mut T在 Rust 中被称为“裸指针” ,这里的星号不是解引用运算符,它是类型名称的一部分。 前者const 表示 “不可变”,不可变意味着指针解引用之后不能直接赋值,后者表示可以赋值。

raw 指针可以绕过 Rust 的安全保障,但随之而来的是需要像 c++ 那样手动管理内存,rust 不再自动清除这块内存,不移动所有权,不管理生命周期。

而且所有用到这个变量的地方都要加上 unsafe{} 字段,告诉编译器,程序员知道这段代码是 unsafe 的。

一个简单的例子

let x = 5;
let raw = &x as *const i32;

let points_at = unsafe { *raw };

println!("raw points at {}", points_at);

再看 构造 LRU cache 的代码:

  • head tail 是原始的指针,因此在初始化时需要用 Box::into_raw函数获得 raw 指针
  • 双链表首尾相连时,需要把代码放在 unsafe 块。

KeyRef 结构

pub struct KeyRef<K> {
    k: *const K,
} 

impl<K: Hash> Hash for KeyRef<K> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        unsafe { (*self.k).hash(state) }
    }
}

注意这里 k 使用了 *const K类型的 raw pointer,因为这里不需要修改指向内存的值,而上面的 head tail 是需要修改双链表的前后关系的,所以需要 mut 关键字

LRU cache 的 Put/Get 操作

put/get 是 LRU cache 最基本的操作,get 可以说是 put 的简化版,所以理解了 put 操作以后 get 自然而然就理解了。

put 函数的流程如下:

  1. 先查 hash map,如果有 key 就更新旧的 value,同时把 node 移到链表头,注意:因为 hash map里存的是指向 key 内存的指针,所以移动 node 的时候,hash map 的 key 指针仍然有效,因此不需要调整 hash map
  2. 如果没有,新建一个 node 并添加到链表头,如果容量已满,移除链表最后一个元素
  3. hash map, key 为指向 node.key 的指针,value 为 node 指针,添加到 hash map 里

所以,理解 lru-rs put 函数的关键在于理解 rust code 是操作 raw pointer 的。

具体关于 rust raw pointer 的解引用比我想象的要复杂的多,比如下面这段代码就包含了 3 次解引用,绕来绕去一不小心就会出错。

let old_key = KeyRef {
    k: unsafe { &(*(*(*self.tail).prev).key.as_ptr()) },
};

所以,初步使用体验下来,感觉 rust 的代码,设计指针操作的还是很麻烦的,甚至比 c/c++ 还要负杂。

put/get 操作的具体实现就不一一解释了,直接看源码会更清楚 https://github.com/jeromefroe/lru-rs.git

贴一个我实现的简化版本

use std::collections::HashMap;
use core::ptr;

struct Node {
    key: i32,
    value: i32,
    prev: *mut Node,
    next: *mut Node,
}

impl Node {
    fn new(key: i32, val: i32) -> Self {
        Node {
            key,
            value: val,
            prev: ptr::null_mut(),
            next: ptr::null_mut(),
        }
    }

    fn new_null() -> Self {
        Node {
            key: -1,
            value: -1,
            prev: ptr::null_mut(),
            next: ptr::null_mut(),
        }
    }

}

struct LRUCache {
    map: HashMap<i32, Box<Node>>,
    cap: i32,

    head: *mut Node,
    tail: *mut Node,
}

impl LRUCache {

    fn new(capacity: i32) -> Self {
        let cache = LRUCache{
            map: HashMap::new(),
            cap: capacity,
            head: Box::into_raw(Box::new(Node::new_null())),
            tail: Box::into_raw(Box::new(Node::new_null())),
        };
        unsafe {
            (*cache.head).next = cache.tail;
            (*cache.tail).prev = cache.head;
        }
        cache
    }

    fn get(&mut self, key: i32) -> i32 {
        if let Some(v) = self.map.get_mut(&key) {
            let node_ptr: *mut Node = &mut **v;
            let value = (*(*v)).value;

            self.detach(node_ptr);
            self.push_front(node_ptr);

            return value
        } else {
            return -1
        }
    }

    fn put(&mut self, key: i32, value: i32) {
        let node_ptr = self.map.get_mut(&key).map(|node| {
            let node_ptr : *mut Node = &mut **node;
            node_ptr
        });

        match node_ptr {
            Some(node_ptr) => {
                unsafe {
                    (*node_ptr).value = value;
                }
                self.detach(node_ptr);
                self.push_front(node_ptr);
            }
            None => {
                if self.cap == 0 {
                    return
                }

                let mut node: Box<Node> = Box::new(Node::new(key, value));

                if self.len() == self.cap() {
                    let oldest_key: i32;
                    unsafe {
                        oldest_key = (*(*(self.tail)).prev).key;
                    };

                    let oldest_node: *mut Node = self.map.get_mut(&oldest_key).map(|node| {
                        let node_ptr : *mut Node = &mut **node;
                        node_ptr
                    }).unwrap();

                    self.detach(oldest_node);

                    // remove key from hash map
                    self.map.remove(&oldest_key).unwrap();
                }

                let node_ptr: *mut Node = &mut *node;
                self.push_front(node_ptr);
                self.map.insert(key, node);
            }
        }
    }

    fn len(&self) -> usize {
        self.map.len()
    }
    fn cap(&self) -> usize {
        self.cap as usize
    }

    fn push_front(&mut self, n: *mut Node) {
        unsafe {
            (*n).next = (*self.head).next;
            (*n).prev = self.head;
            (*self.head).next = n;
            (*(*n).next).prev = n;
        }
    }

    fn detach(&mut self, n: *mut Node) {
        unsafe {
            (*(*n).prev).next = (*n).next;
            (*(*n).next).prev = (*n).prev;
        }
    }

    fn print_all(&self) {
        let mut head = self.head;
        let tail = self.tail;

        while head != tail {
            let cur: *mut Node;
            unsafe {
                cur = (*head).next;

                println!("key = {}, value = {}", (*cur).key, (*cur).value);
                head = cur;
            }
        }
    }
}

fn main() {
    let mut lru = LRUCache::new(2);

    lru.put(1, 1);
    lru.put(2, 2);
    let v = lru.get(1);
    println!("key = {}, val = {}", 1, v);

    lru.put(3, 3);
    let v = lru.get(2);
    println!("key = {}, val = {}", 2, v);

    lru.put(4, 4);

    let v = lru.get(1);
    println!("key = {}, val = {}", 1, v);
    let v = lru.get(3);
    println!("key = {}, val = {}", 3, v);
    let v = lru.get(4);
    println!("key = {}, val = {}", 4, v);

}

参考资料

  1. https://rustcc.cn/article?id=76e5f3fb-20b9-48c9-8fc6-a0aad40ced8c