Rust concurrency, plus the KV server's network processing and network security

Understanding concurrency and parallelism
One of the creators of Golang has a very insightful and intuitive explanation: concurrency is the ability to deal with many things at once, while parallelism is the means of doing many things at once.
We handle the things we need to do in multiple threads or multiple asynchronous tasks; that is the ability of concurrency. Running those threads or async tasks simultaneously on a multi-core, multi-CPU machine is parallelism as a means of execution. You could say concurrency empowers parallelism: once we have the capability of concurrency, parallelism follows naturally.

When dealing with concurrency, the difficulty is not creating multiple threads to distribute the work, but synchronizing these concurrent tasks. Let's look at three common working patterns in the concurrent state: the free-competition mode, the map/reduce mode, and the DAG mode:
(figure omitted: the free-competition, map/reduce, and DAG concurrency modes)
In free-competition mode, tasks compete for the shared resource in no particular order; in map/reduce mode, the work is split up, each piece receives the same processing, and the results are then assembled in a certain order; in DAG mode, the work is divided into disjoint sub-tasks with dependencies, which are executed concurrently according to those dependencies. (Compare this with how concurrency is handled in C++ concurrent programming; a tiny map/reduce sketch follows below.)
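
For intuition, here is a minimal map/reduce-style sketch of my own with plain threads (not from the course): the work is split across threads, and synchronization happens only when the results are collected.

use std::thread;

fn main() {
    let data = vec![1, 2, 3, 4, 5, 6, 7, 8];

    // map: each thread processes its own chunk independently
    let handles: Vec<_> = data
        .chunks(4)
        .map(|chunk| {
            let chunk = chunk.to_vec();
            thread::spawn(move || chunk.iter().map(|x| x * x).sum::<i32>())
        })
        .collect();

    // reduce: the only synchronization point is joining the threads
    let total: i32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
    println!("sum of squares: {}", total);
}
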
Behind these concurrency modes, what concurrency primitives can we use? These two lectures focus on five concepts: Atomic, Mutex, Condvar, Channel, and the Actor model. Today we cover the first two, Atomic and Mutex.

Atomic is the foundation of all concurrency primitives, underpinning the synchronization of concurrent tasks. Behind it lies the CAS principle:
the most basic guarantee is that a single instruction can read a memory location, check whether its value equals an expected value, and if so, modify it to a new value. This is the compare-and-swap operation, CAS for short. It is the cornerstone of almost all of the operating system's concurrency primitives, and it lets us implement a lock that works properly. compare_exchange is the CAS operation Rust provides; it compiles down to the CPU's corresponding CAS instruction.
But for the problems caused by automatic compiler/CPU reordering optimizations (points 3 and 4 in the course's list), we still need some extra handling, which is what this function's two strange Ordering parameters are for. (You can relate this to statement reordering in C++.)


pub enum Ordering {
    Relaxed,
    Release,
    Acquire,
    AcqRel,
    SeqCst,
}
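
To make the CAS + Ordering story concrete, here is a minimal toy spinlock of my own (not the course's code), built on compare_exchange: it acquires with Acquire and releases with Release.

use std::sync::atomic::{AtomicBool, Ordering};

pub struct SpinLock {
    locked: AtomicBool,
}

impl SpinLock {
    pub const fn new() -> Self {
        Self { locked: AtomicBool::new(false) }
    }

    pub fn lock(&self) {
        // CAS loop: try to flip false -> true; spin until we succeed
        while self
            .locked
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            std::hint::spin_loop();
        }
    }

    pub fn unlock(&self) {
        // Release pairs with the Acquire in lock(), so writes made inside
        // the critical section are visible to the next thread that locks
        self.locked.store(false, Ordering::Release);
    }
}
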

What I personally use Atomic for most is building various lock-free data structures. For example, say we need a global ID generator. You could of course use a module like UUID to generate unique IDs, but if the IDs also need to be ordered, then AtomicUsize is the best choice.
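
A minimal sketch of that idea:

use std::sync::atomic::{AtomicUsize, Ordering};

static NEXT_ID: AtomicUsize = AtomicUsize::new(1);

/// Returns a process-wide unique, monotonically increasing ID.
pub fn next_id() -> usize {
    // fetch_add is a single atomic read-modify-write; Relaxed is enough here
    // because we only need atomicity of the counter itself, not ordering
    // with other memory operations.
    NEXT_ID.fetch_add(1, Ordering::Relaxed)
}
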

Mutex (mutex and spin lock); you can also relate this to how mutexes and condition variables build synchronization mechanisms in C++.
Although Atomic can handle the locking needs of the free-competition mode, it is not all that convenient to use. We need a higher-level concurrency primitive to ensure that the software system controls multiple threads' access to the same shared resource, so that each thread gets exclusive, mutually exclusive access when touching it.

SpinLock, as the name suggests, makes a thread busy-wait on the CPU (spinning, like the earlier while loop) until the lock for the critical section becomes available. Implementing mutual exclusion with a SpinLock has limited applicability, though: if the protected critical section is too large, overall performance drops sharply, as the CPU busy-waits and wastes resources without doing useful work. It is not suitable as a general-purpose mechanism.
Mutex lock: with a Mutex, a thread is scheduled out while waiting for the lock and scheduled back in when the lock becomes available.
It may sound like SpinLock is very inefficient, but that is not necessarily so; it depends on the size of the lock's critical section. If there is very little code to execute inside the critical section, a SpinLock can be worthwhile compared with the context switch a Mutex causes. In the Linux kernel, there are many places where only a SpinLock can be used.
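
In everyday Rust user code the usual tool is std::sync::Mutex; a minimal sketch of sharing one counter across threads:

use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    let counter = Arc::new(Mutex::new(0));

    let handles: Vec<_> = (0..4)
        .map(|_| {
            let counter = Arc::clone(&counter);
            thread::spawn(move || {
                // lock() blocks (the thread may be scheduled out) until acquired;
                // the guard releases the lock when dropped at end of scope
                let mut n = counter.lock().unwrap();
                *n += 1;
            })
        })
        .collect();

    for h in handles {
        h.join().unwrap();
    }
    assert_eq!(*counter.lock().unwrap(), 4);
}
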

Atomic / Mutex solve the synchronization problem of concurrent tasks in free-competition mode, and they also handle synchronization in map/reduce mode well, because synchronization there only happens at the map and reduce stages.
However, they do not solve a higher-level problem, the DAG mode: what do we do when access has to happen in a certain order, or when there are before/after dependencies?

The typical scenario for this problem is the producer-consumer model: after the producer produces content, there needs to be a mechanism to notify the consumer that it can consume. For example, when data arrives on a socket, notify the processing thread to handle it; once processing completes, notify the socket send/receive thread to send the data back.

Condvar should be similar to the condition variable in C++. Note the comparison.
Therefore, the operating system also provides Condvar. A Condvar supports two operations. wait: the thread waits in a queue until a certain condition is met. notify: when the condvar's condition is met, the current thread notifies the waiting threads that they can be woken up. The notification can be a single notify, multiple notifies, or even a broadcast (notify everyone). In practice, Condvar is often used together with Mutex: the Mutex guarantees mutual exclusion when the condition is read and written, while the Condvar controls the waiting and waking of threads.
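
A minimal sketch of that classic Mutex + Condvar pairing (essentially the pattern documented in the standard library): one thread sets a flag and notifies, the other waits for it.

use std::sync::{Arc, Condvar, Mutex};
use std::thread;

fn main() {
    let pair = Arc::new((Mutex::new(false), Condvar::new()));
    let pair2 = Arc::clone(&pair);

    thread::spawn(move || {
        let (lock, cvar) = &*pair2;
        let mut ready = lock.lock().unwrap();
        *ready = true;
        // wake up one waiting thread
        cvar.notify_one();
    });

    let (lock, cvar) = &*pair;
    let mut ready = lock.lock().unwrap();
    // wait() atomically releases the lock and suspends the thread;
    // the loop guards against spurious wakeups
    while !*ready {
        ready = cvar.wait(ready).unwrap();
    }
    println!("consumer: producer signalled readiness");
}
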

Channel
Using Mutex and Condvar to handle the complex DAG concurrency mode is harder still, so Rust also provides a variety of Channels for communication between concurrent tasks. A Channel confines the lock to the small region where the queue is written and read, and then completely separates readers from writers: readers just read data and writers just write data. For developers, apart from a potential context switch, it has nothing to do with locks; it feels like accessing a local queue.
Compared with Mutex, Channel has the highest level of abstraction and the most intuitive interface, and the mental burden of using it is much smaller. With Mutex you must be very careful to avoid deadlocks, control the size of critical sections, and guard against every possible accident.

When implementing a Channel, different tools are chosen for different usage scenarios. Rust provides several flavors of channel (a small example follows this list):
oneshot: probably the simplest Channel. The writer sends data exactly once, and the reader reads it exactly once. This kind of one-shot synchronization between threads can be done with a oneshot channel. Because of its special purpose, it can be implemented directly with an atomic swap.

bounded: a bounded channel has a queue, but the queue has an upper limit. Once the queue is full, the writer must also be suspended and wait. When blocking occurs, as soon as the reader consumes data, the channel internally uses Condvar's notify_one to wake one writer so it can continue writing.

unbounded: the queue has no upper limit; if it fills up, it grows automatically. We know that many Rust data structures, such as Vec and VecDeque, expand automatically. Compared with bounded, apart from never blocking writers, the implementation is very similar.

For all these channel types, the implementation ideas for synchronous and asynchronous versions are similar; the main difference lies in what gets suspended and woken. In the synchronous world it is threads; in the asynchronous world it is tasks with very fine granularity.
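
As a small illustration, here is a bounded channel from the standard library's mpsc module (sync_channel gives the bounded, blocking flavor described above):

use std::sync::mpsc;
use std::thread;

fn main() {
    // a bounded channel with capacity 2: send() blocks when the queue is full
    let (tx, rx) = mpsc::sync_channel(2);

    let producer = thread::spawn(move || {
        for i in 0..5 {
            tx.send(i).unwrap();
        }
        // tx is dropped here, which closes the channel and ends the receiver loop
    });

    for msg in rx {
        println!("got {}", msg);
    }
    producer.join().unwrap();
}
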

Stage practice (4): Build a simple KV server - network processing
(For the protobuf parsing, you can also compare with the C++ project.)
Up to now we have been using the mysterious async-prost library, which magically handled the framing and unframing of TCP data. The main idea is to prepend a header carrying the frame length when serializing; when deserializing, read the header first, obtain the length, and then read that much data. Today's challenge is to handle the framing and unframing logic ourselves, without relying on async-prost, on top of the KV server we completed last time. Once you master this skill and combine it with protobuf, you can design any protocol that carries real business.

protobuf solves the problem of how to define protocol messages, but how to tell one message apart from the next is a headache: we need to define a suitable delimiter. Delimiter + message data makes a frame.
(Many TCP-based protocols use \r\n as the delimiter, such as FTP; some use the message length as the delimiter, such as gRPC; and some mix the two, such as Redis's RESP. More complex ones like HTTP use \r\n to separate headers, \r\n\r\n between header and body, and put the body length in a header. A delimiter like \r\n suits protocols whose messages are ASCII data; delimiting by length suits protocols whose messages are binary. The protobuf our KV server carries is binary, so we put a length in front of the payload as the frame delimiter.)
Tokio's tokio-util library has already handled the main framing needs, including LinesDelimited (for \r\n delimiters) and LengthDelimited (for length delimiters). Using the latter looks like this:

let mut stream = Framed::new(stream, LengthDelimitedCodec::new());

(Why design it ourselves? Because real needs vary: it is not always just a delimiter determining a length. For example, you might also want to mark whether compression is applied, or whether other special processing is needed. Library code is limited by the fixed functionality its interface exposes.)
To be closer to reality, we take the highest bit of the 4-byte length as the compression flag: if it is set, the payload that follows is a gzip-compressed protobuf; otherwise it is plain protobuf.

(figure omitted: frame layout - a 4-byte length header whose top bit is the compression flag, followed by the payload)

As usual, let's first define a trait to handle this logic:

pub trait FrameCoder
where
    Self: Message + Sized + Default,
{
    /// Encode a Message into a frame
    fn encode_frame(&self, buf: &mut BytesMut) -> Result<(), KvError>;
    /// Decode a complete frame into a Message
    fn decode_frame(buf: &mut BytesMut) -> Result<Self, KvError>;
}

Now implement the trait:


use std::io::{Read, Write};

use crate::{CommandRequest, CommandResponse, KvError};
use bytes::{Buf, BufMut, BytesMut};
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use prost::Message;
use tokio::io::{AsyncRead, AsyncReadExt};
use tracing::debug;

/// The length header occupies 4 bytes in total
pub const LEN_LEN: usize = 4;
/// The length itself uses 31 bits, so the largest frame is 2 GB
const MAX_FRAME: usize = 2 * 1024 * 1024 * 1024;
/// If the payload exceeds 1436 bytes, compress it
const COMPRESSION_LIMIT: usize = 1436;
/// The bit marking compression (the highest bit of the 4-byte length)
const COMPRESSION_BIT: usize = 1 << 31;

/// Handle encoding/decoding of frames
pub trait FrameCoder
where
    Self: Message + Sized + Default,
{
    /// Encode a Message into a frame
    fn encode_frame(&self, buf: &mut BytesMut) -> Result<(), KvError> {
        let size = self.encoded_len();

        if size >= MAX_FRAME {
            return Err(KvError::FrameError);
        }

        // write the length first; if compression is needed, rewrite it later
        buf.put_u32(size as _);

        if size > COMPRESSION_LIMIT {
            let mut buf1 = Vec::with_capacity(size);
            self.encode(&mut buf1)?;

            // BytesMut supports a logical split (which can be unsplit later),
            // so take the 4 length bytes off the front and clear them
            let payload = buf.split_off(LEN_LEN);
            buf.clear();

            // gzip-compress; see the flate2 docs for details
            let mut encoder = GzEncoder::new(payload.writer(), Compression::default());
            encoder.write_all(&buf1[..])?;

            // once compression is done, take the BytesMut back from the gzip encoder
            let payload = encoder.finish()?.into_inner();
            debug!("Encode a frame: size {}({})", size, payload.len());

            // write the post-compression length (with the compression bit set)
            buf.put_u32((payload.len() | COMPRESSION_BIT) as _);

            // merge the BytesMut back together
            buf.unsplit(payload);

            Ok(())
        } else {
            self.encode(buf)?;
            Ok(())
        }
    }

    /// Decode a complete frame into a Message
    fn decode_frame(buf: &mut BytesMut) -> Result<Self, KvError> {
        // take the first 4 bytes and extract the length and the compression bit
        let header = buf.get_u32() as usize;
        let (len, compressed) = decode_header(header);
        debug!("Got a frame: msg len {}, compressed {}", len, compressed);

        if compressed {
            // decompress
            let mut decoder = GzDecoder::new(&buf[..len]);
            let mut buf1 = Vec::with_capacity(len * 2);
            decoder.read_to_end(&mut buf1)?;
            buf.advance(len);

            // decode into the corresponding message
            Ok(Self::decode(&buf1[..])?)
        } else {
            let msg = Self::decode(&buf[..len])?;
            buf.advance(len);
            Ok(msg)
        }
    }
}

impl FrameCoder for CommandRequest {}
impl FrameCoder for CommandResponse {}

fn decode_header(header: usize) -> (usize, bool) {
    let len = header & !COMPRESSION_BIT;
    let compressed = header & COMPRESSION_BIT == COMPRESSION_BIT;
    (len, compressed)
}

If you are wondering why COMPRESSION_LIMIT is set to 1436:
the Ethernet MTU is 1500. Subtracting the 20-byte IP header and the 20-byte TCP header leaves 1460. TCP packets usually carry some options (such as timestamps), and IP packets may as well, so we reserve 20 bytes for those; subtracting our 4-byte length header then leaves 1436 (1500 - 20 - 20 - 20 - 4), the largest message that avoids fragmentation. Anything bigger is likely to be fragmented, so we simply compress it.

So far this code touches nothing related to socket IO; it is pure logic. Next we need to connect it to the TcpStream we use between server and client. The piece in between lets the stream read one complete frame: read the 4-byte header first, then read exactly that many payload bytes, using library calls such as:

stream.read_exact(&mut buf[LEN_LEN..]).await?;
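
The course's actual helper isn't shown here, but based on the description above, a minimal sketch of read_frame could look like the following. It lives in the same frame module, so LEN_LEN, decode_header and KvError are in scope; treat it as a reconstruction, not the original code.

use bytes::{BufMut, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt};

/// Read one complete frame (4-byte header + payload) from the stream into buf.
pub async fn read_frame<S>(stream: &mut S, buf: &mut BytesMut) -> Result<(), KvError>
where
    S: AsyncRead + Unpin + Send,
{
    // read the 4-byte big-endian header to learn the payload length
    let header = stream.read_u32().await? as usize;
    let (len, _compressed) = decode_header(header);

    // write the header back, make room for the payload, then fill it exactly
    buf.reserve(LEN_LEN + len);
    buf.put_u32(header as _);
    buf.resize(LEN_LEN + len, 0);
    stream.read_exact(&mut buf[LEN_LEN..]).await?;

    Ok(())
}
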

Next, let's think about how to encapsulate the server and the client.
On the server side, we wrap the handling in a process() method:


#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();
    let addr = "127.0.0.1:9527";
    let service: Service = ServiceInner::new(MemTable::new()).into();
    let listener = TcpListener::bind(addr).await?;
    info!("Start listening on {}", addr);
    loop {
        let (stream, addr) = listener.accept().await?;
        info!("Client {:?} connected", addr);
        let stream = ProstServerStream::new(stream, service.clone());
        tokio::spawn(async move { stream.process().await });
    }
}

This process() method is essentially a wrapper around the while loop inside tokio::spawn in examples/server.rs:


while let Some(Ok(cmd)) = stream.next().await {
    info!("Got a new command: {:?}", cmd);
    let res = svc.execute(cmd);
    stream.send(res).await.unwrap();
}

For the client, we also hope to be able to directly execute() a command and get the result:


#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    let addr = "127.0.0.1:9527";
    // connect to the server
    let stream = TcpStream::connect(addr).await?;

    let mut client = ProstClientStream::new(stream);

    // build an HSET command
    let cmd = CommandRequest::new_hset("table1", "hello", "world".to_string().into());

    // send the HSET command
    let data = client.execute(cmd).await?;
    info!("Got response {:?}", data);

    Ok(())
}

This execute() is actually an encapsulation of the sending and receiving code in examples/client.rs:


client.send(cmd).await?;
if let Some(Ok(data)) = client.next().await {
    info!("Got response {:?}", data);
}

Okay. First let's look at the data structure the server uses to handle a TcpStream. It needs to contain the TcpStream itself, plus the Service we built earlier for processing client commands. So the server-side structure holds these two parts:


pub struct ProstServerStream<S> {
    inner: S,
    service: Service,
}

The client's structure for processing TcpStream only needs to contain TcpStream:


pub struct ProstClientStream<S> {
    inner: S,
}

Here we still use a generic parameter S. In the future, if we want to support WebSocket, or TLS on top of TCP, this layer of code will not need to change. This again shows the benefit of generic parameters, just like the store trait earlier.

The next step is to implement process and execute specifically.


mod frame;
use bytes::BytesMut;
pub use frame::{read_frame, FrameCoder};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::info;

use crate::{CommandRequest, CommandResponse, KvError, Service};

/// Handle reads/writes on a socket accepted by the server
pub struct ProstServerStream<S> {
    inner: S,
    service: Service,
}

/// Handle reads/writes on the client socket
pub struct ProstClientStream<S> {
    inner: S,
}

impl<S> ProstServerStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    pub fn new(stream: S, service: Service) -> Self {
        Self {
            inner: stream,
            service,
        }
    }

    pub async fn process(mut self) -> Result<(), KvError> {
        while let Ok(cmd) = self.recv().await {
            info!("Got a new command: {:?}", cmd);
            let res = self.service.execute(cmd);
            self.send(res).await?;
        }
        // info!("Client {:?} disconnected", self.addr);
        Ok(())
    }

    async fn send(&mut self, msg: CommandResponse) -> Result<(), KvError> {
        let mut buf = BytesMut::new();
        msg.encode_frame(&mut buf)?;
        let encoded = buf.freeze();
        self.inner.write_all(&encoded[..]).await?;
        Ok(())
    }

    async fn recv(&mut self) -> Result<CommandRequest, KvError> {
        let mut buf = BytesMut::new();
        let stream = &mut self.inner;
        read_frame(stream, &mut buf).await?;
        CommandRequest::decode_frame(&mut buf)
    }
}

impl<S> ProstClientStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    pub fn new(stream: S) -> Self {
        Self { inner: stream }
    }

    pub async fn execute(&mut self, cmd: CommandRequest) -> Result<CommandResponse, KvError> {
        self.send(cmd).await?;
        Ok(self.recv().await?)
    }

    async fn send(&mut self, msg: CommandRequest) -> Result<(), KvError> {
        let mut buf = BytesMut::new();
        msg.encode_frame(&mut buf)?;
        let encoded = buf.freeze();
        self.inner.write_all(&encoded[..]).await?;
        Ok(())
    }

    async fn recv(&mut self) -> Result<CommandResponse, KvError> {
        let mut buf = BytesMut::new();
        let stream = &mut self.inner;
        read_frame(stream, &mut buf).await?;
        CommandResponse::decode_frame(&mut buf)
    }
}

With this done, the server and client code becomes more concise: process() replaces the server's processing loop, execute() replaces the client's command-execution flow, and both now use our custom frame handling. That is the improvement in this section. The custom streams take a generic parameter S, so supporting new protocol types later requires no changes to this layer.

Stage practice (5): Build a simple KV server - network security
So, when our application architecture sits on top of TCP, how do we use TLS to secure the channel between client and server?
To use TLS, we first need an x509 certificate. TLS needs an x509 certificate so the client can verify that the server is trusted, and, optionally, so the server can verify that the client is trusted too.
For testing convenience, we need the ability to generate our own CA certificate, server certificate, and even client certificates. The details of certificate generation are not covered today; I previously wrote a library called certify that can generate these various certificates. Add it to Cargo.toml:

[dev-dependencies]

certify = "0.3"

Then create a fixtures directory in the project root to hold the certificates, create examples/gen_cert.rs, and add the following code:


use anyhow::Result;
use certify::{generate_ca, generate_cert, load_ca, CertType, CA};
use tokio::fs;

struct CertPem {
    cert_type: CertType,
    cert: String,
    key: String,
}

#[tokio::main]
async fn main() -> Result<()> {
    let pem = create_ca()?;
    gen_files(&pem).await?;
    let ca = load_ca(&pem.cert, &pem.key)?;
    let pem = create_cert(&ca, &["kvserver.acme.inc"], "Acme KV server", false)?;
    gen_files(&pem).await?;
    let pem = create_cert(&ca, &[], "awesome-device-id", true)?;
    gen_files(&pem).await?;
    Ok(())
}

fn create_ca() -> Result<CertPem> {
    let (cert, key) = generate_ca(
        &["acme.inc"],
        "CN",
        "Acme Inc.",
        "Acme CA",
        None,
        Some(10 * 365),
    )?;
    Ok(CertPem {
        cert_type: CertType::CA,
        cert,
        key,
    })
}

fn create_cert(ca: &CA, domains: &[&str], cn: &str, is_client: bool) -> Result<CertPem> {
    let (days, cert_type) = if is_client {
        (Some(365), CertType::Client)
    } else {
        (Some(5 * 365), CertType::Server)
    };
    let (cert, key) = generate_cert(ca, domains, "CN", "Acme Inc.", cn, None, is_client, days)?;

    Ok(CertPem {
        cert_type,
        cert,
        key,
    })
}

async fn gen_files(pem: &CertPem) -> Result<()> {
    let name = match pem.cert_type {
        CertType::Client => "client",
        CertType::Server => "server",
        CertType::CA => "ca",
    };
    fs::write(format!("fixtures/{}.cert", name), pem.cert.as_bytes()).await?;
    fs::write(format!("fixtures/{}.key", name), pem.key.as_bytes()).await?;
    Ok(())
}

This code is straightforward: it first generates a CA certificate, then uses it to issue the server and client certificates, all written to the newly created fixtures directory. Run it with cargo run --example gen_cert. We will use these certificates and keys in the tests later.

The specific details about TLS will not be elaborated.
For the KV server, after TLS is applied, the data encapsulation of the entire protocol looks like the figure below:
(figure omitted: protocol layering - our frames carried over TLS, over TCP)
Many people go numb at the mere mention of TLS or SSL because of past bad experiences with openssl: its code base is too complex, its API is unfriendly, and compiling and linking it are painful. The TLS experience in Rust, however, is quite good: Rust has a solid wrapper for openssl, and there is also rustls, written in pure Rust without depending on openssl. Tokio further provides TLS support that fits the Tokio ecosystem, with both openssl and rustls flavors available.
Today we will use tokio-rustls to write the TLS support. I believe you will see during the implementation how easy it is to add TLS at the network layer to protect an application.
First, add tokio-rustls to Cargo.toml.
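
The original does not give a version. The code below uses the older rustls API (internal::pemfile, DNSNameRef), which roughly corresponds to the tokio-rustls 0.22 line, so the entry would look something like this; the version numbers are my assumption:

[dependencies]
# assumed versions matching the rustls-0.19-era API used below
tokio-rustls = "0.22"
rustls-native-certs = "0.5"
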
Then create src/network/tls.rs and write the following code (remember to declare this module in src/network/mod.rs):


use std::io::Cursor;
use std::sync::Arc;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::rustls::{internal::pemfile, Certificate, ClientConfig, ServerConfig};
use tokio_rustls::rustls::{AllowAnyAuthenticatedClient, NoClientAuth, PrivateKey, RootCertStore};
use tokio_rustls::webpki::DNSNameRef;
use tokio_rustls::TlsConnector;
use tokio_rustls::{
    client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream, TlsAcceptor,
};

use crate::KvError;

use crate::KvError;

/// The KV server's own ALPN (Application-Layer Protocol Negotiation) protocol name
const ALPN_KV: &str = "kv";

/// Holds the TLS ServerConfig and provides accept() to upgrade the underlying protocol to TLS
#[derive(Clone)]
pub struct TlsServerAcceptor {
    inner: Arc<ServerConfig>,
}

/// Holds the TLS client config and provides connect() to upgrade the underlying protocol to TLS
#[derive(Clone)]
pub struct TlsClientConnector {
    pub config: Arc<ClientConfig>,
    pub domain: Arc<String>,
}

impl TlsClientConnector {
    /// Load the client cert / CA cert and build the ClientConfig
    pub fn new(
        domain: impl Into<String>,
        identity: Option<(&str, &str)>,
        server_ca: Option<&str>,
    ) -> Result<Self, KvError> {
        let mut config = ClientConfig::new();

        // if there is a client certificate, load it
        if let Some((cert, key)) = identity {
            let certs = load_certs(cert)?;
            let key = load_key(key)?;
            config.set_single_client_cert(certs, key)?;
        }

        // load the locally trusted root certificate chain
        config.root_store = match rustls_native_certs::load_native_certs() {
            Ok(store) | Err((Some(store), _)) => store,
            Err((None, error)) => return Err(error.into()),
        };

        // if a CA cert that signed the server cert is provided, load it: even though
        // the server cert is not in the root chain, this CA can still verify it
        if let Some(cert) = server_ca {
            let mut buf = Cursor::new(cert);
            config.root_store.add_pem_file(&mut buf).unwrap();
        }

        Ok(Self {
            config: Arc::new(config),
            domain: Arc::new(domain.into()),
        })
    }

    /// Perform the TLS handshake, turning the underlying stream into a TLS stream
    pub async fn connect<S>(&self, stream: S) -> Result<ClientTlsStream<S>, KvError>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send,
    {
        let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())
            .map_err(|_| KvError::Internal("Invalid DNS name".into()))?;

        let stream = TlsConnector::from(self.config.clone())
            .connect(dns, stream)
            .await?;

        Ok(stream)
    }
}

impl TlsServerAcceptor {
    /// Load the server cert / CA cert and build the ServerConfig
    pub fn new(cert: &str, key: &str, client_ca: Option<&str>) -> Result<Self, KvError> {
        let certs = load_certs(cert)?;
        let key = load_key(key)?;

        let mut config = match client_ca {
            None => ServerConfig::new(NoClientAuth::new()),
            Some(cert) => {
                // if client certs are issued by some CA, load that CA cert into the trust chain
                let mut cert = Cursor::new(cert);
                let mut client_root_cert_store = RootCertStore::empty();
                client_root_cert_store
                    .add_pem_file(&mut cert)
                    .map_err(|_| KvError::CertifcateParseError("CA", "cert"))?;

                let client_auth = AllowAnyAuthenticatedClient::new(client_root_cert_store);
                ServerConfig::new(client_auth)
            }
        };

        // load the server certificate
        config
            .set_single_cert(certs, key)
            .map_err(|_| KvError::CertifcateParseError("server", "cert"))?;
        config.set_protocols(&[Vec::from(&ALPN_KV[..])]);

        Ok(Self {
            inner: Arc::new(config),
        })
    }

    /// Perform the TLS handshake, turning the underlying stream into a TLS stream
    pub async fn accept<S>(&self, stream: S) -> Result<ServerTlsStream<S>, KvError>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send,
    {
        let acceptor = TlsAcceptor::from(self.inner.clone());
        Ok(acceptor.accept(stream).await?)
    }
}

fn load_certs(cert: &str) -> Result<Vec<Certificate>, KvError> {
    let mut cert = Cursor::new(cert);
    pemfile::certs(&mut cert).map_err(|_| KvError::CertifcateParseError("server", "cert"))
}

fn load_key(key: &str) -> Result<PrivateKey, KvError> {
    let mut cursor = Cursor::new(key);

    // first try to load the private key as PKCS8
    if let Ok(mut keys) = pemfile::pkcs8_private_keys(&mut cursor) {
        if !keys.is_empty() {
            return Ok(keys.remove(0));
        }
    }

    // then try to load it as an RSA key
    cursor.set_position(0);
    if let Ok(mut keys) = pemfile::rsa_private_keys(&mut cursor) {
        if !keys.is_empty() {
            return Ok(keys.remove(0));
        }
    }

    // unsupported private key type
    Err(KvError::CertifcateParseError("private", "key"))
}

Although this is 100-plus lines, the main work is just generating the ServerConfig / ClientConfig that tokio-rustls needs from the provided certificates. Once the config is handled, the core logic is really just the client's connect() method and the server's accept() method. Both accept a stream satisfying AsyncRead + AsyncWrite + Unpin + Send; as in the previous lecture, we don't want the TLS code to accept only TcpStream, hence the generic parameter S.
After TlsConnector / TlsAcceptor completes connect/accept, we get a TlsStream, which also satisfies AsyncRead + AsyncWrite + Unpin + Send, so all subsequent operations can be carried out on it.

Thanks to the clean interface design so far, especially ProstClientStream / ProstServerStream taking generic parameters, the TLS code can be embedded seamlessly. For example, on the client:


// newly added code
let connector = TlsClientConnector::new("kvserver.acme.inc", None, Some(ca_cert))?;

let stream = TcpStream::connect(addr).await?;

// newly added code
let stream = connector.connect(stream).await?;

let mut client = ProstClientStream::new(stream);

Just by changing the stream passed to ProstClientStream from a TcpStream to the resulting TlsStream, we support TLS seamlessly.

The complete server:


use anyhow::Result;
use kv3::{MemTable, ProstServerStream, Service, ServiceInner, TlsServerAcceptor};
use tokio::net::TcpListener;
use tracing::info;

#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();
    let addr = "127.0.0.1:9527";

    // later these will come from a config file
    let server_cert = include_str!("../fixtures/server.cert");
    let server_key = include_str!("../fixtures/server.key");

    let acceptor = TlsServerAcceptor::new(server_cert, server_key, None)?;
    let service: Service = ServiceInner::new(MemTable::new()).into();
    let listener = TcpListener::bind(addr).await?;
    info!("Start listening on {}", addr);
    loop {
        let tls = acceptor.clone();
        let (stream, addr) = listener.accept().await?;
        info!("Client {:?} connected", addr);
        let stream = tls.accept(stream).await?;
        let stream = ProstServerStream::new(stream, service.clone());
        tokio::spawn(async move { stream.process().await });
    }
}

And the complete client:


use anyhow::Result;
use kv3::{CommandRequest, ProstClientStream, TlsClientConnector};
use tokio::net::TcpStream;
use tracing::info;

#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    // later this will come from configuration
    let ca_cert = include_str!("../fixtures/ca.cert");

    let addr = "127.0.0.1:9527";
    // connect to the server
    let connector = TlsClientConnector::new("kvserver.acme.inc", None, Some(ca_cert))?;
    let stream = TcpStream::connect(addr).await?;
    let stream = connector.connect(stream).await?;

    let mut client = ProstClientStream::new(stream);

    // build an HSET command
    let cmd = CommandRequest::new_hset("table1", "hello", "world".to_string().into());

    // send the HSET command
    let data = client.execute(cmd).await?;
    info!("Got response {:?}", data);

    Ok(())
}

Compared with the previous lecture's code, the updated client and server each gain only one extra line, wrapping the TcpStream into a TlsStream. This is the power of interface-oriented programming with traits: the system's components can come from different crates, but as long as their interfaces are consistent (or we write an adapter to make them consistent), they plug in seamlessly.

Source: blog.csdn.net/weixin_53344209/article/details/130132729