Golang网络编程实战

前言

网络编程是Golang的重要使用场景,它能以短小精悍的代码实现强大的功能,为你打造一把瑞士军刀。

特点: goroutine+阻塞通信的网络通信模型

Golang网络编程简述

目前主流TCP server一般均采用的都是”Non-Block + I/O多路复用”(有的也结合了多线程、多进程)。不过I/O多路复用也给使用者带来了不小的复杂度,以至于后续出现了许多高性能的I/O多路复用框架, 比如libevent、libev、libuv等,以帮助开发者简化开发复杂性,降低心智负担。不过Go的设计者似乎认为I/O多路复用的这种通过回调机制割裂控制流 的方式依旧复杂,且有悖于“一般逻辑”设计,为此Go语言将该“复杂性”隐藏在Runtime中了:Go开发者无需关注socket是否是 non-block的,也无需亲自注册文件描述符的回调,只需在每个连接对应的goroutine中以“block I/O”的方式对待socket处理即可。

基本API:

  • 创建监听服务器: net.Listen()
  • 接受连接: listen.Accept()
  • 创建连接: net.Dial()或net.DialTimeout()

获得连接conn后,即可以在上进行读写,以完成业务逻辑。Go runtime隐藏了I/O多路复用的复杂性。语言使用者只需采用goroutine+Block I/O的模式即可满足大部分场景需求。

//go-tcpsock/server.go
func HandleConn(conn net.Conn) {
    
    
    defer conn.Close()

    for {
    
    
        // read from the connection
        // ... ...
        // write to the connection
        //... ...
    }
}

func main() {
    
    
    listen, err := net.Listen("tcp", ":8888")
    if err != nil {
    
    
        fmt.Println("listen error: ", err)
        return
    }

    for {
    
    
        conn, err := listen.Accept()
        if err != nil {
    
    
            fmt.Println("accept error: ", err)
            break
        }

        // start a new goroutine to handle the new connection
        go HandleConn(conn)
    }
}

简易的HTTP服务器

package main
 
import (
    "fmt"
    "log"
    "net/http"
)
 
// w表示response对象,返回给客户端的内容都在对象里处理
// r表示客户端请求对象,包含了请求头,请求参数等等
func index(w http.ResponseWriter, r *http.Request) {
    
    
    // 往w里写入内容,就会在浏览器里输出
    fmt.Fprintf(w, "Hello golang http!")
}
 
func main() {
    
    
    // 设置路由,如果访问/,则调用index方法
    http.HandleFunc("/", index)
 
    // 启动web服务,监听9090端口
    err := http.ListenAndServe(":9090", nil)
    if err != nil {
    
    
        log.Fatal("ListenAndServe: ", err)
    }
}

最简单的聊天服务程序

服务端

package main

import (
	"bufio"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
)

var port int

func init() {
    
    
	flag.IntVar(&port, "port", 7575, "the tunnelthing port")
}

func main() {
    
    
	flag.Parse()

	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err != nil {
    
    
		log.Fatal(err)
	}

	input := make(chan string)
	go getTerminalInput(input)

	go broadcaster(input)

	for {
    
    
		conn, err := listener.Accept()
		if err != nil {
    
    
			log.Print(err)
			continue
		}

		go handleConn(conn)
	}
}

// 对外发送消息的通道
type client chan<- string

var (
	entering = make(chan client)
	leaving  = make(chan client)
	// 所有接收的客户消息
	messages = make(chan string)
)

func getTerminalInput(input chan string) {
    
    
	scanner := bufio.NewScanner(os.Stdin)
	for {
    
    
		scanner.Scan()
		input <- scanner.Text()
	}
}

func broadcaster(input chan string) {
    
    

	// 所有链接的客户端
	clients := make(map[client]bool)
	for {
    
    
		select {
    
    
		case in := <-input:
			for cli := range clients {
    
    
				cli <- "Server:" + in
			}
		case msg := <-messages:
			// 把所有接收的消息广播给所有的客户
			// 发送消息通道
			for cli := range clients {
    
    
				cli <- msg
			}

		case cli := <-entering:
			clients[cli] = true

		case cli := <-leaving:
			delete(clients, cli)
			close(cli)
		}
	}

}

// 创建一个对外发送消息的新通道
func handleConn(conn net.Conn) {
    
    
	// 对外发送消息的通道
	ch := make(chan string)
	go clientWriter(conn, ch)

	who := conn.RemoteAddr().String()
	ch <- "You are " + who
	messages <- who + " has arrived"
	entering <- ch

	input := bufio.NewScanner(conn)
	for input.Scan() {
    
    
		// 注意,忽略 input.Err() 中可能的错误
		fmt.Println(who + ": " + input.Text() + "\n")
		messages <- who + ": " + input.Text()
	}

	leaving <- ch
	messages <- who + " has left"
	conn.Close()
}

// 将messages回传给客户端
func clientWriter(conn net.Conn, ch <-chan string) {
    
    
	for msg := range ch {
    
    
		// 注意,忽略网络层面的错误
		fmt.Fprintln(conn, msg)
		// fmt.Println(msg)
	}
}

func mustCopy(dst io.Writer, src io.Reader) {
    
    
	if _, err := io.Copy(dst, src); err != nil {
    
    
		log.Fatal(err)
	}
}

客户端

// 一个简单的TCP服务器读/写客户端
package main

import (
	"flag"
	"io"
	"log"
	"net"
	"os"
)

var (
	target string
)

func init() {
    
    
	flag.StringVar(&target, "target", "", "the target (<host>:<port>)")
}

func main() {
    
    
	flag.Parse()

	conn, err := net.Dial("tcp", target)
	if err != nil {
    
    
		log.Fatal(err)
	}
	done := make(chan struct{
    
    })

	// 显示服务端发来的消息
	go func() {
    
    
		io.Copy(os.Stdout, conn) // 注意:忽略错误
		log.Println("done")
		done <- struct{
    
    }{
    
    } // 向主Goroutine发出信号
	}()

	mustCopy(conn, os.Stdin) // 将控制台输入发送到服务端
	conn.Close()
	<-done // 等待后台goroutine完成
}

func mustCopy(dst io.Writer, src io.Reader) {
    
    
	if _, err := io.Copy(dst, src); err != nil {
    
    
		log.Fatal(err)
	}
}

端口转发

package main

import (
    "golang.org/x/crypto/ssh"
    "io"
    "log"
    "net"
)

// main 万物的起源
func main() {
    
    
    sshAddr := "000.00.000.00:22" // 服务器的 ip:ssh端口
    sshUser := "root"             // 用户名,可以新建一个特定用户
    sshPasswd := "6666666.66666"  // 密码
    Remote := "127.0.0.1:3306"    // 转发到远程的端口,未开放但是你想访问的端口
    Listen := "127.0.0.1:3307"    // 需要转发的本地端口,本地访问时使用的端口

    serverClient, err := ssh.Dial("tcp", sshAddr, &ssh.ClientConfig{
    
    
        User:            sshUser,
        Auth:            []ssh.AuthMethod{
    
    ssh.Password(sshPasswd)},
        HostKeyCallback: ssh.InsecureIgnoreHostKey(),
    })
    if err != nil {
    
    
        log.Fatalf("ssh服务器连接异常: %s", err.Error())
    }
    defer serverClient.Close() // 程序执行完关闭连接,养成好习惯
    log.Println("与服务器建立ssh连接成功啦")

    // 监听本地映射端口 这样访问本地端口的时候我就可以发现啦
    listener, err := net.Listen("tcp", Listen)
    if err != nil {
    
    
        log.Fatalf(err.Error())
    }
    defer listener.Close() // 程序执行完关闭连接,养成好习惯

    for {
    
    
        // 接收本地发送的数据
        conn, err := listener.Accept()
        if err != nil {
    
     // 养成好习惯处理错误
            log.Println(err)
            return
        }
        // 开启一个协程去处理这次的消息
        go func(conn net.Conn) {
    
    
            //建立ssh到后端服务的连接
            forwardConn, err := serverClient.Dial("tcp", Remote)
            if err != nil {
    
     // 养成好习惯处理错误
                log.Fatalln(err.Error())
            }
            log.Println("ssh端口映射隧道建立成功")
            defer forwardConn.Close() // 用完要关掉
            // 转发工作
            go io.Copy(forwardConn, conn) // 客户端发送给服务端的数据拷贝给服务端
            io.Copy(conn, forwardConn)    // 服务端发送给客户端的数据拷贝给客户端
        }(conn)
    }
}

或者:

package main

import (
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/signal"
)

var (
	target string
	port   int
)

func init() {
    
    
	flag.StringVar(&target, "target", "", "the target (<host>:<port>)")
	flag.IntVar(&port, "port", 7757, "the tunnelthing port")
}

func main() {
    
    
	flag.Parse()

	signals := make(chan os.Signal, 1)
	stop := make(chan bool)
	signal.Notify(signals, os.Interrupt)
	go func() {
    
    
		for _ = range signals {
    
    
			fmt.Println("\nReceived an interrupt, stopping...")
			stop <- true
		}
	}()

	incoming, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err != nil {
    
    
		log.Fatalf("could not start server on %d: %v", port, err)
	}
	fmt.Printf("server running on %d\n", port)

	client, err := incoming.Accept()
	if err != nil {
    
    
		log.Fatal("could not accept client connection", err)
	}
	defer client.Close()
	fmt.Printf("client '%v' connected!\n", client.RemoteAddr())

	target, err := net.Dial("tcp", target)
	if err != nil {
    
    
		log.Fatal("could not connect to target", err)
	}
	defer target.Close()
	fmt.Printf("connection to server %v established!\n", target.RemoteAddr())

	go func() {
    
     io.Copy(target, client) }()
	go func() {
    
     io.Copy(client, target) }()

	<-stop
}

HTTP代理服务器

用golang实现代理服务器也比较容易:

package main  
import (  
 "bytes"  
 "fmt"  
 "io"  
 "log"  
 "net"  
 "net/url"  
 "strings"  
)   
func main() {
    
      
 // tcp 连接,监听 8080 端口  
 l, err := net.Listen("tcp", ":8080")  
 if err != nil {
    
      
  log.Panic(err)  
 }   
 // 死循环,每当遇到连接时,调用 handle  
 for {
    
      
  client, err := l.Accept()  
  if err != nil {
    
      
   log.Panic(err)  
  }  
  go handle(client)  
 }  
}  
func handle(client net.Conn) {
    
     
 if client == nil {
    
      
  return  
 }  
 defer client.Close()  
 log.Printf("remote addr: %v\n", client.RemoteAddr())  
 // 用来存放客户端数据的缓冲区  
 var b [1024]byte  
 //从客户端获取数据  
 n, err := client.Read(b[:])  
 if err != nil {
    
      
  log.Println(err)  
  return 
 }  
 var method, URL, address string  
 // 从客户端数据读入 method,url  
 fmt.Sscanf(string(b[:bytes.IndexByte(b[:], '\n')]), "%s%s", &method, &URL)  
 hostPortURL, err := url.Parse(URL)  
 if err != nil {
    
      
  log.Println(err)  
  return  
 }  
 // 如果方法是 CONNECT,则为 https 协议  
 if method == "CONNECT" {
    
      
  address = hostPortURL.Scheme + ":" + hostPortURL.Opaque  
 } else {
    
     //否则为 http 协议  
  address = hostPortURL.Host  
  // 如果 host 不带端口,则默认为 80  
  if strings.Index(hostPortURL.Host, ":") == -1 {
    
     //host 不带端口, 默认 80  
   address = hostPortURL.Host + ":80"  
  }  
 }  
 //获得了请求的 host 和 port,向服务端发起 tcp 连接  
 server, err := net.Dial("tcp", address)  
 if err != nil {
    
      
  log.Println(err)  
  return  
 }  
 //如果使用 https 协议,需先向客户端表示连接建立完毕  
 if method == "CONNECT" {
    
      
  fmt.Fprint(client, "HTTP/1.1 200 Connection established\r\n\r\n")  
 } else {
    
     //如果使用 http 协议,需将从客户端得到的 http 请求转发给服务端  
  server.Write(b[:n])  
 }  
 //将客户端的请求转发至服务端,将服务端的响应转发给客户端。io.Copy 为阻塞函数,文件描述符不关闭就不停止  
 go io.Copy(server, client)  
 io.Copy(client, server)  
} 

内网穿透

内网穿透让你在家里机器上运行应用,然后被外界访问。当然,得有一台公网机器做转发。

  • frp
  • https://github.com/linzhepeng/go-NAT.git

断点续传

下载分片:

func (d *FileDownloader) downloadPart(c filePart) error {
    
    
    headers := map[string]string{
    
    
        "User-Agent": userAgent,
        "Range":      fmt.Sprintf("bytes=%v-%v", c.From, c.To),
    }
    // 或得一个 request
    r, err := getNewRequest(d.url, "GET", headers)
    if err != nil {
    
    
        return err
    }
    // 打印要下载的分片信息
    log.Printf("开始[%d]下载from:%d to:%d\n", c.Index, c.From, c.To)
    resp, err := http.DefaultClient.Do(r)
    if resp.StatusCode > 299 {
    
    
        return errors.New(fmt.Sprintf("服务器错误状态码: %v", resp.StatusCode))
    }
    // 最后关闭文件
    defer func(Body io.ReadCloser) {
    
    
        err := Body.Close()
        if err != nil {
    
    
        }
    }(resp.Body)
    // 读取 Body 的响应数据
    bs, err := io.ReadAll(resp.Body)
    if err != nil {
    
    
        return err
    }
    if len(bs) != (c.To - c.From + 1) {
    
    
        return errors.New("下载文件分片长度错误")
    }
    c.Data = bs
    // c完成了后就加入到下载器中
    d.doneFilePart[c.Index] = c
    return nil
}

多线程下载:

func (d *FileDownloader) Run() error {
    
    
    // 获取文件大小
    fileTotalSize, err := d.getHeaderInfo()
    if err != nil {
    
    
        fmt.Printf("hello!!")
        return err
    }
    d.fileSize = fileTotalSize
    jobs := make([]filePart, d.totalPart)
    // 这里进行均分
    eachSize := fileTotalSize / d.totalPart
    for i := range jobs {
    
    
        jobs[i].Index = i
        // 计算 form
        if i == 0 {
    
    
            jobs[i].From = 0
        } else {
    
    
            jobs[i].From = jobs[i-1].To + 1
        }
        // 计算 to
        if i < d.totalPart-1 {
    
    
            jobs[i].To = jobs[i].From + eachSize
        } else {
    
    
            // 最后一个filePart
            jobs[i].To = fileTotalSize - 1
        }
    }
    // 多线程下载
    var wg sync.WaitGroup
    for _, j := range jobs {
    
    
        wg.Add(1)
        go func(job filePart) {
    
    
            defer wg.Done()
            err := d.downloadPart(job)
            if err != nil {
    
    
                log.Println("下载文件失败:", err, job)
            }
        }(j)
    }
    wg.Wait()
    return d.mergeFileParts()
}

合并文件:

// 合并要下载的文件
func (d *FileDownloader) mergeFileParts() error {
    
    
    path := filepath.Join(d.outputDir, d.outputFileName)
    log.Println("开始合并文件")
    // 创建文件
    mergedFile, err := os.Create(path)
    if err != nil {
    
    
        return err
    }
    // 最后关闭文件
    defer func(mergedFile *os.File) {
    
    
        err := mergedFile.Close()
        if err != nil {
    
    
        }
    }(mergedFile)
    // sha256是一种密码散列函数,说白了它就是一个哈希函数。
    //对于任意长度的消息,SHA256都会产生一个256bit长度的散列值,
    //称为消息摘要,可以用一个长度为64的十六进制字符串表示。
    fileMd5 := sha256.New()
    totalSize := 0
    // 合并的工作
    for _, s := range d.doneFilePart {
    
    
        _, err := mergedFile.Write(s.Data)
        if err != nil {
    
    
            fmt.Printf("error when merge file: %v\n", err)
        }
        fileMd5.Write(s.Data)    // 更新哈希值
        totalSize += len(s.Data) // 更新长度
    }
    // 校验文件完整性
    if totalSize != d.fileSize {
    
    
        return errors.New("文件不完整")
    }
    // 检验 MD5
    if d.md5 == "" {
    
    
        // 将整个文件进行了 Sum 运算, 该函数返回一个 16 进制串,转成字符串之后,
        // 和 d.md5比较,起到了一个校验的效果
        if hex.EncodeToString(fileMd5.Sum(nil)) != d.md5 {
    
    
            return errors.New("文件损坏")
        } else {
    
    
            log.Println("文件SHA-256校验成功")
        }
    }
    return nil
}

参考链接

猜你喜欢

转载自blog.csdn.net/jgku/article/details/132030525