前言
网络编程是Golang的重要使用场景,它能以短小精悍的代码实现强大的功能,为你打造一把瑞士军刀。
特点: goroutine+阻塞通信的网络通信模型
Golang网络编程简述
目前主流TCP server一般均采用的都是”Non-Block + I/O多路复用”(有的也结合了多线程、多进程)。不过I/O多路复用也给使用者带来了不小的复杂度,以至于后续出现了许多高性能的I/O多路复用框架, 比如libevent、libev、libuv等,以帮助开发者简化开发复杂性,降低心智负担。不过Go的设计者似乎认为I/O多路复用的这种通过回调机制割裂控制流 的方式依旧复杂,且有悖于“一般逻辑”设计,为此Go语言将该“复杂性”隐藏在Runtime中了:Go开发者无需关注socket是否是 non-block的,也无需亲自注册文件描述符的回调,只需在每个连接对应的goroutine中以“block I/O”的方式对待socket处理即可。
基本API:
- 创建监听服务器: net.Listen()
- 接受连接: listen.Accept()
- 创建连接: net.Dial()或net.DialTimeout()
获得连接conn后,即可以在上进行读写,以完成业务逻辑。Go runtime隐藏了I/O多路复用的复杂性。语言使用者只需采用goroutine+Block I/O的模式即可满足大部分场景需求。
//go-tcpsock/server.go
func HandleConn(conn net.Conn) {
defer conn.Close()
for {
// read from the connection
// ... ...
// write to the connection
//... ...
}
}
func main() {
listen, err := net.Listen("tcp", ":8888")
if err != nil {
fmt.Println("listen error: ", err)
return
}
for {
conn, err := listen.Accept()
if err != nil {
fmt.Println("accept error: ", err)
break
}
// start a new goroutine to handle the new connection
go HandleConn(conn)
}
}
简易的HTTP服务器
package main
import (
"fmt"
"log"
"net/http"
)
// w表示response对象,返回给客户端的内容都在对象里处理
// r表示客户端请求对象,包含了请求头,请求参数等等
func index(w http.ResponseWriter, r *http.Request) {
// 往w里写入内容,就会在浏览器里输出
fmt.Fprintf(w, "Hello golang http!")
}
func main() {
// 设置路由,如果访问/,则调用index方法
http.HandleFunc("/", index)
// 启动web服务,监听9090端口
err := http.ListenAndServe(":9090", nil)
if err != nil {
log.Fatal("ListenAndServe: ", err)
}
}
最简单的聊天服务程序
服务端
package main
import (
"bufio"
"flag"
"fmt"
"io"
"log"
"net"
"os"
)
var port int
func init() {
flag.IntVar(&port, "port", 7575, "the tunnelthing port")
}
func main() {
flag.Parse()
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
log.Fatal(err)
}
input := make(chan string)
go getTerminalInput(input)
go broadcaster(input)
for {
conn, err := listener.Accept()
if err != nil {
log.Print(err)
continue
}
go handleConn(conn)
}
}
// 对外发送消息的通道
type client chan<- string
var (
entering = make(chan client)
leaving = make(chan client)
// 所有接收的客户消息
messages = make(chan string)
)
func getTerminalInput(input chan string) {
scanner := bufio.NewScanner(os.Stdin)
for {
scanner.Scan()
input <- scanner.Text()
}
}
func broadcaster(input chan string) {
// 所有链接的客户端
clients := make(map[client]bool)
for {
select {
case in := <-input:
for cli := range clients {
cli <- "Server:" + in
}
case msg := <-messages:
// 把所有接收的消息广播给所有的客户
// 发送消息通道
for cli := range clients {
cli <- msg
}
case cli := <-entering:
clients[cli] = true
case cli := <-leaving:
delete(clients, cli)
close(cli)
}
}
}
// 创建一个对外发送消息的新通道
func handleConn(conn net.Conn) {
// 对外发送消息的通道
ch := make(chan string)
go clientWriter(conn, ch)
who := conn.RemoteAddr().String()
ch <- "You are " + who
messages <- who + " has arrived"
entering <- ch
input := bufio.NewScanner(conn)
for input.Scan() {
// 注意,忽略 input.Err() 中可能的错误
fmt.Println(who + ": " + input.Text() + "\n")
messages <- who + ": " + input.Text()
}
leaving <- ch
messages <- who + " has left"
conn.Close()
}
// 将messages回传给客户端
func clientWriter(conn net.Conn, ch <-chan string) {
for msg := range ch {
// 注意,忽略网络层面的错误
fmt.Fprintln(conn, msg)
// fmt.Println(msg)
}
}
func mustCopy(dst io.Writer, src io.Reader) {
if _, err := io.Copy(dst, src); err != nil {
log.Fatal(err)
}
}
客户端
// 一个简单的TCP服务器读/写客户端
package main
import (
"flag"
"io"
"log"
"net"
"os"
)
var (
target string
)
func init() {
flag.StringVar(&target, "target", "", "the target (<host>:<port>)")
}
func main() {
flag.Parse()
conn, err := net.Dial("tcp", target)
if err != nil {
log.Fatal(err)
}
done := make(chan struct{
})
// 显示服务端发来的消息
go func() {
io.Copy(os.Stdout, conn) // 注意:忽略错误
log.Println("done")
done <- struct{
}{
} // 向主Goroutine发出信号
}()
mustCopy(conn, os.Stdin) // 将控制台输入发送到服务端
conn.Close()
<-done // 等待后台goroutine完成
}
func mustCopy(dst io.Writer, src io.Reader) {
if _, err := io.Copy(dst, src); err != nil {
log.Fatal(err)
}
}
端口转发
package main
import (
"golang.org/x/crypto/ssh"
"io"
"log"
"net"
)
// main 万物的起源
func main() {
sshAddr := "000.00.000.00:22" // 服务器的 ip:ssh端口
sshUser := "root" // 用户名,可以新建一个特定用户
sshPasswd := "6666666.66666" // 密码
Remote := "127.0.0.1:3306" // 转发到远程的端口,未开放但是你想访问的端口
Listen := "127.0.0.1:3307" // 需要转发的本地端口,本地访问时使用的端口
serverClient, err := ssh.Dial("tcp", sshAddr, &ssh.ClientConfig{
User: sshUser,
Auth: []ssh.AuthMethod{
ssh.Password(sshPasswd)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
if err != nil {
log.Fatalf("ssh服务器连接异常: %s", err.Error())
}
defer serverClient.Close() // 程序执行完关闭连接,养成好习惯
log.Println("与服务器建立ssh连接成功啦")
// 监听本地映射端口 这样访问本地端口的时候我就可以发现啦
listener, err := net.Listen("tcp", Listen)
if err != nil {
log.Fatalf(err.Error())
}
defer listener.Close() // 程序执行完关闭连接,养成好习惯
for {
// 接收本地发送的数据
conn, err := listener.Accept()
if err != nil {
// 养成好习惯处理错误
log.Println(err)
return
}
// 开启一个协程去处理这次的消息
go func(conn net.Conn) {
//建立ssh到后端服务的连接
forwardConn, err := serverClient.Dial("tcp", Remote)
if err != nil {
// 养成好习惯处理错误
log.Fatalln(err.Error())
}
log.Println("ssh端口映射隧道建立成功")
defer forwardConn.Close() // 用完要关掉
// 转发工作
go io.Copy(forwardConn, conn) // 客户端发送给服务端的数据拷贝给服务端
io.Copy(conn, forwardConn) // 服务端发送给客户端的数据拷贝给客户端
}(conn)
}
}
或者:
package main
import (
"flag"
"fmt"
"io"
"log"
"net"
"os"
"os/signal"
)
var (
target string
port int
)
func init() {
flag.StringVar(&target, "target", "", "the target (<host>:<port>)")
flag.IntVar(&port, "port", 7757, "the tunnelthing port")
}
func main() {
flag.Parse()
signals := make(chan os.Signal, 1)
stop := make(chan bool)
signal.Notify(signals, os.Interrupt)
go func() {
for _ = range signals {
fmt.Println("\nReceived an interrupt, stopping...")
stop <- true
}
}()
incoming, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
log.Fatalf("could not start server on %d: %v", port, err)
}
fmt.Printf("server running on %d\n", port)
client, err := incoming.Accept()
if err != nil {
log.Fatal("could not accept client connection", err)
}
defer client.Close()
fmt.Printf("client '%v' connected!\n", client.RemoteAddr())
target, err := net.Dial("tcp", target)
if err != nil {
log.Fatal("could not connect to target", err)
}
defer target.Close()
fmt.Printf("connection to server %v established!\n", target.RemoteAddr())
go func() {
io.Copy(target, client) }()
go func() {
io.Copy(client, target) }()
<-stop
}
HTTP代理服务器
用golang实现代理服务器也比较容易:
package main
import (
"bytes"
"fmt"
"io"
"log"
"net"
"net/url"
"strings"
)
func main() {
// tcp 连接,监听 8080 端口
l, err := net.Listen("tcp", ":8080")
if err != nil {
log.Panic(err)
}
// 死循环,每当遇到连接时,调用 handle
for {
client, err := l.Accept()
if err != nil {
log.Panic(err)
}
go handle(client)
}
}
func handle(client net.Conn) {
if client == nil {
return
}
defer client.Close()
log.Printf("remote addr: %v\n", client.RemoteAddr())
// 用来存放客户端数据的缓冲区
var b [1024]byte
//从客户端获取数据
n, err := client.Read(b[:])
if err != nil {
log.Println(err)
return
}
var method, URL, address string
// 从客户端数据读入 method,url
fmt.Sscanf(string(b[:bytes.IndexByte(b[:], '\n')]), "%s%s", &method, &URL)
hostPortURL, err := url.Parse(URL)
if err != nil {
log.Println(err)
return
}
// 如果方法是 CONNECT,则为 https 协议
if method == "CONNECT" {
address = hostPortURL.Scheme + ":" + hostPortURL.Opaque
} else {
//否则为 http 协议
address = hostPortURL.Host
// 如果 host 不带端口,则默认为 80
if strings.Index(hostPortURL.Host, ":") == -1 {
//host 不带端口, 默认 80
address = hostPortURL.Host + ":80"
}
}
//获得了请求的 host 和 port,向服务端发起 tcp 连接
server, err := net.Dial("tcp", address)
if err != nil {
log.Println(err)
return
}
//如果使用 https 协议,需先向客户端表示连接建立完毕
if method == "CONNECT" {
fmt.Fprint(client, "HTTP/1.1 200 Connection established\r\n\r\n")
} else {
//如果使用 http 协议,需将从客户端得到的 http 请求转发给服务端
server.Write(b[:n])
}
//将客户端的请求转发至服务端,将服务端的响应转发给客户端。io.Copy 为阻塞函数,文件描述符不关闭就不停止
go io.Copy(server, client)
io.Copy(client, server)
}
内网穿透
内网穿透让你在家里机器上运行应用,然后被外界访问。当然,得有一台公网机器做转发。
- frp
- https://github.com/linzhepeng/go-NAT.git
断点续传
下载分片:
func (d *FileDownloader) downloadPart(c filePart) error {
headers := map[string]string{
"User-Agent": userAgent,
"Range": fmt.Sprintf("bytes=%v-%v", c.From, c.To),
}
// 或得一个 request
r, err := getNewRequest(d.url, "GET", headers)
if err != nil {
return err
}
// 打印要下载的分片信息
log.Printf("开始[%d]下载from:%d to:%d\n", c.Index, c.From, c.To)
resp, err := http.DefaultClient.Do(r)
if resp.StatusCode > 299 {
return errors.New(fmt.Sprintf("服务器错误状态码: %v", resp.StatusCode))
}
// 最后关闭文件
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
}
}(resp.Body)
// 读取 Body 的响应数据
bs, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if len(bs) != (c.To - c.From + 1) {
return errors.New("下载文件分片长度错误")
}
c.Data = bs
// c完成了后就加入到下载器中
d.doneFilePart[c.Index] = c
return nil
}
多线程下载:
func (d *FileDownloader) Run() error {
// 获取文件大小
fileTotalSize, err := d.getHeaderInfo()
if err != nil {
fmt.Printf("hello!!")
return err
}
d.fileSize = fileTotalSize
jobs := make([]filePart, d.totalPart)
// 这里进行均分
eachSize := fileTotalSize / d.totalPart
for i := range jobs {
jobs[i].Index = i
// 计算 form
if i == 0 {
jobs[i].From = 0
} else {
jobs[i].From = jobs[i-1].To + 1
}
// 计算 to
if i < d.totalPart-1 {
jobs[i].To = jobs[i].From + eachSize
} else {
// 最后一个filePart
jobs[i].To = fileTotalSize - 1
}
}
// 多线程下载
var wg sync.WaitGroup
for _, j := range jobs {
wg.Add(1)
go func(job filePart) {
defer wg.Done()
err := d.downloadPart(job)
if err != nil {
log.Println("下载文件失败:", err, job)
}
}(j)
}
wg.Wait()
return d.mergeFileParts()
}
合并文件:
// 合并要下载的文件
func (d *FileDownloader) mergeFileParts() error {
path := filepath.Join(d.outputDir, d.outputFileName)
log.Println("开始合并文件")
// 创建文件
mergedFile, err := os.Create(path)
if err != nil {
return err
}
// 最后关闭文件
defer func(mergedFile *os.File) {
err := mergedFile.Close()
if err != nil {
}
}(mergedFile)
// sha256是一种密码散列函数,说白了它就是一个哈希函数。
//对于任意长度的消息,SHA256都会产生一个256bit长度的散列值,
//称为消息摘要,可以用一个长度为64的十六进制字符串表示。
fileMd5 := sha256.New()
totalSize := 0
// 合并的工作
for _, s := range d.doneFilePart {
_, err := mergedFile.Write(s.Data)
if err != nil {
fmt.Printf("error when merge file: %v\n", err)
}
fileMd5.Write(s.Data) // 更新哈希值
totalSize += len(s.Data) // 更新长度
}
// 校验文件完整性
if totalSize != d.fileSize {
return errors.New("文件不完整")
}
// 检验 MD5
if d.md5 == "" {
// 将整个文件进行了 Sum 运算, 该函数返回一个 16 进制串,转成字符串之后,
// 和 d.md5比较,起到了一个校验的效果
if hex.EncodeToString(fileMd5.Sum(nil)) != d.md5 {
return errors.New("文件损坏")
} else {
log.Println("文件SHA-256校验成功")
}
}
return nil
}