lru 算法的原理简而言之就是一个 hash ,一个 double linked list
- Linked List 提供 O(1) 的复杂度对元素进行插入和删除
- hash 提供 O(1) 的复杂度进行查找
本文主要是通过阅读一个 rust 实现的 lru 学习相关语法。
- 如何在结构体里面使用指针?
- rust 是否有 raw pointer 直接指向内存地址,如果能用该怎么用?
Linked List 节点结构体
上面提到,真正的 key/value
是存在双链表的 Node 里,所以需要先定义这个 Node 长什么样,lru-rs
中 LruEntry 表示的就是 node:
- K V 代表的是泛型的类型,
struct LruEntry<K, V> {
key: mem::MaybeUninit<K>,
val: mem::MaybeUninit<V>,
prev: *mut LruEntry<K, V>,
next: *mut LruEntry<K, V>,
}
下面是如何初始化一个 Node,
impl<K, V> LruEntry<K, V> {
fn new(key: K, val: V) -> Self {
LruEntry {
key: mem::MaybeUninit::new(key),
val: mem::MaybeUninit::new(val),
prev: ptr::null_mut(),
next: ptr::null_mut(),
}
}
fn new_sigil() -> Self {
LruEntry {
key: mem::MaybeUninit::uninit(),
val: mem::MaybeUninit::uninit(),
prev: ptr::null_mut(),
next: ptr::null_mut(),
}
}
}
- key value 用
mem::MaybeUninit::new(key)
进行初始化 - prev next 指针用
ptr::null_mut()
初始化
LRU cache 结构体
链表的 node 定义好以后,双链表结构也自然而然就有了。接下来还缺一个 map 结构体,这个可以用 rust 原生的 hash 函数库,然后就可以定义出 LRU 结构体
pub struct LruCache<K, V, S = DefaultHasher> {
map: HashMap<KeyRef<K>, Box<LruEntry<K, V>>, S>,
cap: usize,
// head and tail are sigil nodes to faciliate inserting entries
head: *mut LruEntry<K, V>,
tail: *mut LruEntry<K, V>,
}
- head tail 指针和 Node 里面的指针用同样的定义方法。
- map 用的是 rust 自带的 HashMap 结构,它需要输入一个 key 类型 K,一个 value 类型 V,和 hash 算法 S,不指明 S 的话会调用默认函数。
- K 类型 是 KeyRef,这是自定义的类型,下面将会介绍
- V 类型是用 Box 加上 LruEntry。Box是堆上分配的指针类型,称为“装箱”(boxed),其指针本身在栈,指向的数据在堆,见 参考资料1
新建一个 LRU Cache 的实现
pub fn new(cap: usize) -> LruCache<K, V> {
LruCache::construct(cap, HashMap::with_capacity(cap))
}
fn construct(cap: usize, map: HashMap<KeyRef<K>, Box<LruEntry<K, V>>, S>) -> LruCache<K, V, S> {
let cache = LruCache {
map,
cap,
head: Box::into_raw(Box::new(LruEntry::new_sigil())),
tail: Box::into_raw(Box::new(LruEntry::new_sigil())),
};
unsafe {
(*cache.head).next = cache.tail;
(*cache.tail).prev = cache.head;
}
cache
}
*const T和*mut T在 Rust 中被称为“裸指针”
,这里的星号不是解引用运算符,它是类型名称的一部分。 前者const 表示 “不可变”,不可变意味着指针解引用之后不能直接赋值,后者表示可以赋值。
raw 指针可以绕过 Rust 的安全保障,但随之而来的是需要像 c++ 那样手动管理内存,rust 不再自动清除这块内存,不移动所有权,不管理生命周期。
而且所有用到这个变量的地方都要加上 unsafe{}
字段,告诉编译器,程序员知道这段代码是 unsafe 的。
一个简单的例子
let x = 5;
let raw = &x as *const i32;
let points_at = unsafe { *raw };
println!("raw points at {}", points_at);
再看 构造 LRU cache 的代码:
- head tail 是原始的指针,因此在初始化时需要用
Box::into_raw
函数获得 raw 指针 - 双链表首尾相连时,需要把代码放在
unsafe
块。
KeyRef 结构
pub struct KeyRef<K> {
k: *const K,
}
impl<K: Hash> Hash for KeyRef<K> {
fn hash<H: Hasher>(&self, state: &mut H) {
unsafe { (*self.k).hash(state) }
}
}
注意这里 k 使用了 *const K
类型的 raw pointer,因为这里不需要修改指向内存的值,而上面的 head tail 是需要修改双链表的前后关系的,所以需要 mut 关键字
LRU cache 的 Put/Get 操作
put/get 是 LRU cache 最基本的操作,get 可以说是 put 的简化版,所以理解了 put 操作以后 get 自然而然就理解了。
put 函数的流程如下:
- 先查 hash map,如果有 key 就更新旧的 value,同时把 node 移到链表头,注意:因为 hash map里存的是指向 key 内存的指针,所以移动 node 的时候,hash map 的 key 指针仍然有效,因此不需要调整 hash map
- 如果没有,新建一个 node 并添加到链表头,如果容量已满,移除链表最后一个元素
- hash map, key 为指向 node.key 的指针,value 为 node 指针,添加到 hash map 里
所以,理解 lru-rs
put 函数的关键在于理解 rust code 是操作 raw pointer 的。
具体关于 rust raw pointer 的解引用比我想象的要复杂的多,比如下面这段代码就包含了 3 次解引用,绕来绕去一不小心就会出错。
let old_key = KeyRef {
k: unsafe { &(*(*(*self.tail).prev).key.as_ptr()) },
};
所以,初步使用体验下来,感觉 rust 的代码,设计指针操作的还是很麻烦的,甚至比 c/c++ 还要负杂。
put/get 操作的具体实现就不一一解释了,直接看源码会更清楚 https://github.com/jeromefroe/lru-rs.git
贴一个我实现的简化版本
use std::collections::HashMap;
use core::ptr;
struct Node {
key: i32,
value: i32,
prev: *mut Node,
next: *mut Node,
}
impl Node {
fn new(key: i32, val: i32) -> Self {
Node {
key,
value: val,
prev: ptr::null_mut(),
next: ptr::null_mut(),
}
}
fn new_null() -> Self {
Node {
key: -1,
value: -1,
prev: ptr::null_mut(),
next: ptr::null_mut(),
}
}
}
struct LRUCache {
map: HashMap<i32, Box<Node>>,
cap: i32,
head: *mut Node,
tail: *mut Node,
}
impl LRUCache {
fn new(capacity: i32) -> Self {
let cache = LRUCache{
map: HashMap::new(),
cap: capacity,
head: Box::into_raw(Box::new(Node::new_null())),
tail: Box::into_raw(Box::new(Node::new_null())),
};
unsafe {
(*cache.head).next = cache.tail;
(*cache.tail).prev = cache.head;
}
cache
}
fn get(&mut self, key: i32) -> i32 {
if let Some(v) = self.map.get_mut(&key) {
let node_ptr: *mut Node = &mut **v;
let value = (*(*v)).value;
self.detach(node_ptr);
self.push_front(node_ptr);
return value
} else {
return -1
}
}
fn put(&mut self, key: i32, value: i32) {
let node_ptr = self.map.get_mut(&key).map(|node| {
let node_ptr : *mut Node = &mut **node;
node_ptr
});
match node_ptr {
Some(node_ptr) => {
unsafe {
(*node_ptr).value = value;
}
self.detach(node_ptr);
self.push_front(node_ptr);
}
None => {
if self.cap == 0 {
return
}
let mut node: Box<Node> = Box::new(Node::new(key, value));
if self.len() == self.cap() {
let oldest_key: i32;
unsafe {
oldest_key = (*(*(self.tail)).prev).key;
};
let oldest_node: *mut Node = self.map.get_mut(&oldest_key).map(|node| {
let node_ptr : *mut Node = &mut **node;
node_ptr
}).unwrap();
self.detach(oldest_node);
// remove key from hash map
self.map.remove(&oldest_key).unwrap();
}
let node_ptr: *mut Node = &mut *node;
self.push_front(node_ptr);
self.map.insert(key, node);
}
}
}
fn len(&self) -> usize {
self.map.len()
}
fn cap(&self) -> usize {
self.cap as usize
}
fn push_front(&mut self, n: *mut Node) {
unsafe {
(*n).next = (*self.head).next;
(*n).prev = self.head;
(*self.head).next = n;
(*(*n).next).prev = n;
}
}
fn detach(&mut self, n: *mut Node) {
unsafe {
(*(*n).prev).next = (*n).next;
(*(*n).next).prev = (*n).prev;
}
}
fn print_all(&self) {
let mut head = self.head;
let tail = self.tail;
while head != tail {
let cur: *mut Node;
unsafe {
cur = (*head).next;
println!("key = {}, value = {}", (*cur).key, (*cur).value);
head = cur;
}
}
}
}
fn main() {
let mut lru = LRUCache::new(2);
lru.put(1, 1);
lru.put(2, 2);
let v = lru.get(1);
println!("key = {}, val = {}", 1, v);
lru.put(3, 3);
let v = lru.get(2);
println!("key = {}, val = {}", 2, v);
lru.put(4, 4);
let v = lru.get(1);
println!("key = {}, val = {}", 1, v);
let v = lru.get(3);
println!("key = {}, val = {}", 3, v);
let v = lru.get(4);
println!("key = {}, val = {}", 4, v);
}