Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
4e17fee681
|
|||
|
a8eed30db5
|
|||
| c1eaf43058 | |||
| 9f5aca124d | |||
| 54958724e7 | |||
| 7a17672149 | |||
| 44b807d3d1 | |||
| 0d847462b3 | |||
| deed4207ea | |||
| f6363fed07 | |||
| 1de78f2f06 | |||
| d0122a9771 | |||
| 319518d71d | |||
| be3df9703e | |||
| b92288bbc9 | |||
| 0805549006 | |||
| 033272f38a | |||
|
93b756d9fb
|
|||
|
d71eacdc91
|
|||
|
747fc52c44
|
|||
| ce3ebbbf8a | |||
|
66c8abbcea
|
|||
| b4bffa978c |
@@ -0,0 +1 @@
|
||||
.idea
|
||||
+464
@@ -0,0 +1,464 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUrlEncodeRaw(t *testing.T) {
|
||||
input := "hello world!@#$%^&*()_+-=~`"
|
||||
expected := "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60"
|
||||
result := UrlEncodeRaw(input)
|
||||
if result != expected {
|
||||
t.Errorf("UrlEncodeRaw(%q) = %q; want %q", input, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUrlEncode(t *testing.T) {
|
||||
input := "hello world!@#$%^&*()_+-=~`"
|
||||
expected := `hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60`
|
||||
result := UrlEncode(input)
|
||||
if result != expected {
|
||||
t.Errorf("UrlEncode(%q) = %q; want %q", input, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUrlDecode(t *testing.T) {
|
||||
input := "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60"
|
||||
expected := "hello world!@#$%^&*()_+-=~`"
|
||||
result, err := UrlDecode(input)
|
||||
if err != nil {
|
||||
t.Errorf("UrlDecode(%q) returned error: %v", input, err)
|
||||
}
|
||||
if result != expected {
|
||||
t.Errorf("UrlDecode(%q) = %q; want %q", input, result, expected)
|
||||
}
|
||||
|
||||
// Test for error case
|
||||
invalidInput := "%zz"
|
||||
_, err = UrlDecode(invalidInput)
|
||||
if err == nil {
|
||||
t.Errorf("UrlDecode(%q) expected error, got nil", invalidInput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPostForm_WithValidInput(t *testing.T) {
|
||||
input := map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}
|
||||
|
||||
expected := []byte("key1=value1&key2=value2")
|
||||
|
||||
result := BuildPostForm(input)
|
||||
|
||||
if string(result) != string(expected) {
|
||||
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPostForm_WithEmptyInput(t *testing.T) {
|
||||
input := map[string]string{}
|
||||
|
||||
expected := []byte("")
|
||||
|
||||
result := BuildPostForm(input)
|
||||
|
||||
if string(result) != string(expected) {
|
||||
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPostForm_WithNilInput(t *testing.T) {
|
||||
var input map[string]string
|
||||
|
||||
expected := []byte("")
|
||||
|
||||
result := BuildPostForm(input)
|
||||
|
||||
if string(result) != string(expected) {
|
||||
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequest(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Get(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostRequest(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPost {
|
||||
t.Errorf("Expected 'POST', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Post(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodOptions {
|
||||
t.Errorf("Expected 'OPTIONS', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Options(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPut {
|
||||
t.Errorf("Expected 'PUT', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Put(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodDelete {
|
||||
t.Errorf("Expected 'DELETE', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Delete(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodHead {
|
||||
t.Errorf("Expected 'HEAD', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Head(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body == "OK" {
|
||||
t.Errorf("Expected , got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPatch {
|
||||
t.Errorf("Expected 'PATCH', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Patch(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTraceRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodTrace {
|
||||
t.Errorf("Expected 'TRACE', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Trace(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodConnect {
|
||||
t.Errorf("Expected 'CONNECT', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Connect(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
func TestMethodReturnsCorrectValue(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
req.SetMethodNoError("GET")
|
||||
if req.Method() != "GET" {
|
||||
t.Errorf("Expected 'GET', got %v", req.Method())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetMethodHandlesInvalidInput(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
err := req.SetMethod("我是谁")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetMethodNoErrorSetsMethodCorrectly(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
req.SetMethodNoError("POST")
|
||||
if req.Method() != "POST" {
|
||||
t.Errorf("Expected 'POST', got %v", req.Method())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetMethodNoErrorIgnoresInvalidInput(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
req.SetMethodNoError("你是谁")
|
||||
if req.Method() != "GET" {
|
||||
t.Errorf("Expected '', got %v", req.Method())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUriReturnsCorrectValue(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
if req.Uri() != "https://example.com" {
|
||||
t.Errorf("Expected 'https://example.com', got %v", req.Uri())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetUriHandlesValidInput(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
err := req.SetUri("https://newexample.com")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if req.Uri() != "https://newexample.com" {
|
||||
t.Errorf("Expected 'https://newexample.com', got %v", req.Uri())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetUriHandlesInvalidInput(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
err := req.SetUri("://invalidurl")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetUriNoErrorSetsUriCorrectly(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
req.SetUriNoError("https://newexample.com")
|
||||
if req.Uri() != "https://newexample.com" {
|
||||
t.Errorf("Expected 'https://newexample.com', got %v", req.Uri())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetUriNoErrorIgnoresInvalidInput(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
req.SetUriNoError("://invalidurl")
|
||||
if req.Uri() != "https://example.com" {
|
||||
t.Errorf("Expected 'https://example.com', got %v", req.Uri())
|
||||
}
|
||||
}
|
||||
|
||||
type postmanReply struct {
|
||||
Args struct {
|
||||
} `json:"args"`
|
||||
Form map[string]string `json:"form"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Url string `json:"url"`
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
var reply postmanReply
|
||||
resp, err := NewReq("https://postman-echo.com/get").
|
||||
AddHeader("hello", "nononmo").
|
||||
SetAutoCalcContentLength(true).
|
||||
Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = resp.Body().Unmarshal(&reply)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Println(resp.Body().String())
|
||||
fmt.Println(reply.Headers)
|
||||
fmt.Println(resp.Cookies())
|
||||
}
|
||||
|
||||
type testData struct {
|
||||
name string
|
||||
args *Request
|
||||
want func(*Response) error
|
||||
wantErr bool
|
||||
}
|
||||
|
||||
func headerTestData() []testData {
|
||||
return []testData{
|
||||
{
|
||||
name: "addHeader",
|
||||
args: NewReq("https://postman-echo.com/get").
|
||||
AddHeader("b612", "test-data").
|
||||
AddHeader("b612", "test-header").
|
||||
AddSimpleCookie("b612", "test-cookie").
|
||||
SetHeader("User-Agent", "starnet test"),
|
||||
want: func(resp *Response) error {
|
||||
//fmt.Println(resp.Body().String())
|
||||
if resp == nil {
|
||||
return fmt.Errorf("response is nil")
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
var reply postmanReply
|
||||
err := resp.Body().Unmarshal(&reply)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reply.Headers["b612"] != "test-data, test-header" {
|
||||
return fmt.Errorf("header not found")
|
||||
}
|
||||
if reply.Headers["user-agent"] != "starnet test" {
|
||||
return fmt.Errorf("user-agent not found")
|
||||
}
|
||||
if reply.Headers["cookie"] != "b612=test-cookie" {
|
||||
return fmt.Errorf("cookie not found")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "postForm",
|
||||
args: NewSimpleRequest("https://postman-echo.com/post", "POST").
|
||||
AddHeader("b612", "test-data").
|
||||
AddHeader("b612", "test-header").
|
||||
AddSimpleCookie("b612", "test-cookie").
|
||||
SetHeader("User-Agent", "starnet test").
|
||||
//SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
AddFormData("hello", "world").
|
||||
AddFormData("hello2", "world2").
|
||||
SetMethodNoError("POST"),
|
||||
want: func(resp *Response) error {
|
||||
//fmt.Println(resp.Body().String())
|
||||
if resp == nil {
|
||||
return fmt.Errorf("response is nil")
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
var reply postmanReply
|
||||
err := resp.Body().Unmarshal(&reply)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reply.Headers["b612"] != "test-data, test-header" {
|
||||
return fmt.Errorf("header not found")
|
||||
}
|
||||
if reply.Headers["user-agent"] != "starnet test" {
|
||||
return fmt.Errorf("user-agent not found")
|
||||
}
|
||||
if reply.Headers["cookie"] != "b612=test-cookie" {
|
||||
return fmt.Errorf("cookie not found")
|
||||
}
|
||||
if reply.Form["hello"] != "world" {
|
||||
return fmt.Errorf("form data not found")
|
||||
}
|
||||
if reply.Form["hello2"] != "world2" {
|
||||
return fmt.Errorf("form data not found")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
func TestCurl(t *testing.T) {
|
||||
for _, tt := range headerTestData() {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Curl(tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Curl() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.want != nil {
|
||||
if err := tt.want(got); err != nil {
|
||||
t.Errorf("Curl() = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+120
@@ -0,0 +1,120 @@
|
||||
package starnet
|
||||
|
||||
import "strings"
|
||||
|
||||
var isTokenTable = [127]bool{
|
||||
'!': true,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': true,
|
||||
'\'': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'W': true,
|
||||
'V': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'|': true,
|
||||
'~': true,
|
||||
}
|
||||
|
||||
func IsTokenRune(r rune) bool {
|
||||
i := int(r)
|
||||
return i < len(isTokenTable) && isTokenTable[i]
|
||||
}
|
||||
|
||||
func validMethod(method string) bool {
|
||||
/*
|
||||
Method = "OPTIONS" ; Section 9.2
|
||||
| "GET" ; Section 9.3
|
||||
| "HEAD" ; Section 9.4
|
||||
| "POST" ; Section 9.5
|
||||
| "PUT" ; Section 9.6
|
||||
| "DELETE" ; Section 9.7
|
||||
| "TRACE" ; Section 9.8
|
||||
| "CONNECT" ; Section 9.9
|
||||
| extension-method
|
||||
extension-method = token
|
||||
token = 1*<any CHAR except CTLs or separators>
|
||||
*/
|
||||
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
|
||||
}
|
||||
|
||||
func isNotToken(r rune) bool {
|
||||
return !IsTokenRune(r)
|
||||
}
|
||||
|
||||
func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
|
||||
|
||||
// removeEmptyPort strips the empty port in ":port" to ""
|
||||
// as mandated by RFC 3986 Section 6.2.3.
|
||||
func removeEmptyPort(host string) string {
|
||||
if hasPort(host) {
|
||||
return strings.TrimSuffix(host, ":")
|
||||
}
|
||||
return host
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ICMP struct {
|
||||
Type uint8
|
||||
Code uint8
|
||||
CheckSum uint16
|
||||
Identifier uint16
|
||||
SequenceNum uint16
|
||||
}
|
||||
|
||||
func getICMP(seq uint16) ICMP {
|
||||
icmp := ICMP{
|
||||
Type: 8,
|
||||
Code: 0,
|
||||
CheckSum: 0,
|
||||
Identifier: 0,
|
||||
SequenceNum: seq,
|
||||
}
|
||||
var buffer bytes.Buffer
|
||||
binary.Write(&buffer, binary.BigEndian, icmp)
|
||||
icmp.CheckSum = checkSum(buffer.Bytes())
|
||||
buffer.Reset()
|
||||
|
||||
return icmp
|
||||
}
|
||||
|
||||
func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr, timeout time.Duration) (PingResult, error) {
|
||||
var res PingResult
|
||||
res.RemoteIP = destAddr.String()
|
||||
conn, err := net.DialIP("ip:icmp", nil, destAddr)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
defer conn.Close()
|
||||
var buffer bytes.Buffer
|
||||
binary.Write(&buffer, binary.BigEndian, icmp)
|
||||
|
||||
if _, err := conn.Write(buffer.Bytes()); err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
tStart := time.Now()
|
||||
|
||||
conn.SetReadDeadline((time.Now().Add(timeout)))
|
||||
|
||||
recv := make([]byte, 1024)
|
||||
res.RecvCount, err = conn.Read(recv)
|
||||
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
tEnd := time.Now()
|
||||
res.Duration = tEnd.Sub(tStart)
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
func checkSum(data []byte) uint16 {
|
||||
var (
|
||||
sum uint32
|
||||
length int = len(data)
|
||||
index int
|
||||
)
|
||||
for length > 1 {
|
||||
sum += uint32(data[index])<<8 + uint32(data[index+1])
|
||||
index += 2
|
||||
length -= 2
|
||||
}
|
||||
if length > 0 {
|
||||
sum += uint32(data[index])
|
||||
}
|
||||
sum += (sum >> 16)
|
||||
|
||||
return uint16(^sum)
|
||||
}
|
||||
|
||||
type PingResult struct {
|
||||
Duration time.Duration
|
||||
RecvCount int
|
||||
RemoteIP string
|
||||
}
|
||||
|
||||
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
|
||||
var res PingResult
|
||||
ipAddr, err := net.ResolveIPAddr("ip", ip)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
icmp := getICMP(uint16(seq))
|
||||
return sendICMPRequest(icmp, ipAddr, timeout)
|
||||
}
|
||||
|
||||
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
|
||||
for i := 0; i < retryLimit; i++ {
|
||||
_, err := Ping(ip, 29, timeout)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_Ping(t *testing.T) {
|
||||
fmt.Println(Ping("baidu.com", 29, time.Second*2))
|
||||
fmt.Println(Ping("www.b612.me", 29, time.Second*2))
|
||||
fmt.Println(IsIpPingable("baidu.com", time.Second*2, 3))
|
||||
fmt.Println(IsIpPingable("www.b612.me", time.Second*2, 3))
|
||||
|
||||
}
|
||||
@@ -1,295 +0,0 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 识别头
|
||||
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
|
||||
|
||||
// MsgQueue 为基本的信息单位
|
||||
type MsgQueue struct {
|
||||
ID uint16
|
||||
Msg []byte
|
||||
Conn interface{}
|
||||
}
|
||||
|
||||
// StarQueue 为流数据中的消息队列分发
|
||||
type StarQueue struct {
|
||||
Encode bool
|
||||
Reserve uint16
|
||||
Msgid uint16
|
||||
MsgPool []MsgQueue
|
||||
UnFinMsg sync.Map
|
||||
LastID int //= -1
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
duration time.Duration
|
||||
EncodeFunc func([]byte) []byte
|
||||
DecodeFunc func([]byte) []byte
|
||||
//parseMu sync.Mutex
|
||||
restoreMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewQueue 建立一个新消息队列
|
||||
func NewQueue() *StarQueue {
|
||||
var que StarQueue
|
||||
que.Encode = false
|
||||
que.ctx, que.cancel = context.WithCancel(context.Background())
|
||||
que.duration = 0
|
||||
return &que
|
||||
}
|
||||
|
||||
// Uint32ToByte 4位uint32转byte
|
||||
func Uint32ToByte(src uint32) []byte {
|
||||
res := make([]byte, 4)
|
||||
res[3] = uint8(src)
|
||||
res[2] = uint8(src >> 8)
|
||||
res[1] = uint8(src >> 16)
|
||||
res[0] = uint8(src >> 24)
|
||||
return res
|
||||
}
|
||||
|
||||
// ByteToUint32 byte转4位uint32
|
||||
func ByteToUint32(src []byte) uint32 {
|
||||
var res uint32
|
||||
buffer := bytes.NewBuffer(src)
|
||||
binary.Read(buffer, binary.BigEndian, &res)
|
||||
return res
|
||||
}
|
||||
|
||||
// Uint16ToByte 2位uint16转byte
|
||||
func Uint16ToByte(src uint16) []byte {
|
||||
res := make([]byte, 2)
|
||||
res[1] = uint8(src)
|
||||
res[0] = uint8(src >> 8)
|
||||
return res
|
||||
}
|
||||
|
||||
// ByteToUint16 用于byte转uint16
|
||||
func ByteToUint16(src []byte) uint16 {
|
||||
var res uint16
|
||||
buffer := bytes.NewBuffer(src)
|
||||
binary.Read(buffer, binary.BigEndian, &res)
|
||||
return res
|
||||
}
|
||||
|
||||
// BuildMessage 生成编码后的信息用于发送
|
||||
func (que *StarQueue) BuildMessage(src []byte) []byte {
|
||||
var buff bytes.Buffer
|
||||
que.Msgid++
|
||||
if que.Encode {
|
||||
src = que.EncodeFunc(src)
|
||||
}
|
||||
length := uint32(len(src))
|
||||
buff.Write(header)
|
||||
buff.Write(Uint32ToByte(length))
|
||||
buff.Write(Uint16ToByte(que.Msgid))
|
||||
buff.Write(src)
|
||||
return buff.Bytes()
|
||||
}
|
||||
|
||||
// BuildHeader 生成编码后的Header用于发送
|
||||
func (que *StarQueue) BuildHeader(length uint32) []byte {
|
||||
var buff bytes.Buffer
|
||||
que.Msgid++
|
||||
buff.Write(header)
|
||||
buff.Write(Uint32ToByte(length))
|
||||
buff.Write(Uint16ToByte(que.Msgid))
|
||||
return buff.Bytes()
|
||||
}
|
||||
|
||||
type unFinMsg struct {
|
||||
ID uint16
|
||||
LengthRecv uint32
|
||||
// HeaderMsg 信息头,应当为14位:8位识别码+4位长度码+2位id
|
||||
HeaderMsg []byte
|
||||
RecvMsg []byte
|
||||
}
|
||||
|
||||
// ParseMessage 用于解析收到的msg信息
|
||||
func (que *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
|
||||
tmp, ok := que.UnFinMsg.Load(conn)
|
||||
if ok { //存在未完成的信息
|
||||
lastMsg := tmp.(*unFinMsg)
|
||||
headerLen := len(lastMsg.HeaderMsg)
|
||||
if headerLen < 14 { //未完成头标题
|
||||
//传输的数据不能填充header头
|
||||
if len(msg) < 14-headerLen {
|
||||
//加入header头并退出
|
||||
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
|
||||
que.UnFinMsg.Store(conn, lastMsg)
|
||||
return nil
|
||||
}
|
||||
//获取14字节完整的header
|
||||
header := msg[0 : 14-headerLen]
|
||||
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
|
||||
//检查收到的header是否为认证header
|
||||
//若不是,丢弃并重新来过
|
||||
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
|
||||
que.UnFinMsg.Delete(conn)
|
||||
if len(msg) == 0 {
|
||||
return nil
|
||||
}
|
||||
return que.ParseMessage(msg, conn)
|
||||
}
|
||||
//获得本数据包长度
|
||||
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
|
||||
//获得本数据包ID
|
||||
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
|
||||
//存入列表
|
||||
que.UnFinMsg.Store(conn, lastMsg)
|
||||
msg = msg[14-headerLen:]
|
||||
if uint32(len(msg)) < lastMsg.LengthRecv {
|
||||
lastMsg.RecvMsg = msg
|
||||
que.UnFinMsg.Store(conn, lastMsg)
|
||||
return nil
|
||||
}
|
||||
if uint32(len(msg)) >= lastMsg.LengthRecv {
|
||||
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
|
||||
if que.Encode {
|
||||
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
|
||||
}
|
||||
msg = msg[lastMsg.LengthRecv:]
|
||||
stroeMsg := MsgQueue{
|
||||
ID: lastMsg.ID,
|
||||
Msg: lastMsg.RecvMsg,
|
||||
Conn: conn,
|
||||
}
|
||||
que.MsgPool = append(que.MsgPool, stroeMsg)
|
||||
que.UnFinMsg.Delete(conn)
|
||||
return que.ParseMessage(msg, conn)
|
||||
}
|
||||
} else {
|
||||
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
|
||||
if lastID < 0 {
|
||||
que.UnFinMsg.Delete(conn)
|
||||
return que.ParseMessage(msg, conn)
|
||||
}
|
||||
if len(msg) >= lastID {
|
||||
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
|
||||
if que.Encode {
|
||||
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
|
||||
}
|
||||
stroeMsg := MsgQueue{
|
||||
ID: lastMsg.ID,
|
||||
Msg: lastMsg.RecvMsg,
|
||||
Conn: conn,
|
||||
}
|
||||
que.MsgPool = append(que.MsgPool, stroeMsg)
|
||||
que.UnFinMsg.Delete(conn)
|
||||
if len(msg) == lastID {
|
||||
return nil
|
||||
}
|
||||
msg = msg[lastID:]
|
||||
return que.ParseMessage(msg, conn)
|
||||
}
|
||||
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
|
||||
que.UnFinMsg.Store(conn, lastMsg)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if len(msg) == 0 {
|
||||
return nil
|
||||
}
|
||||
var start int
|
||||
if start = searchHeader(msg); start == -1 {
|
||||
return errors.New("data format error")
|
||||
}
|
||||
msg = msg[start:]
|
||||
lastMsg := unFinMsg{}
|
||||
que.UnFinMsg.Store(conn, &lastMsg)
|
||||
return que.ParseMessage(msg, conn)
|
||||
}
|
||||
|
||||
func checkHeader(msg []byte) bool {
|
||||
if len(msg) != 8 {
|
||||
return false
|
||||
}
|
||||
for k, v := range msg {
|
||||
if v != header[k] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func searchHeader(msg []byte) int {
|
||||
if len(msg) < 8 {
|
||||
return 0
|
||||
}
|
||||
for k, v := range msg {
|
||||
find := 0
|
||||
if v == header[0] {
|
||||
for k2, v2 := range header {
|
||||
if msg[k+k2] == v2 {
|
||||
find++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
if find == 8 {
|
||||
return k
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func bytesMerge(src ...[]byte) []byte {
|
||||
var buff bytes.Buffer
|
||||
for _, v := range src {
|
||||
buff.Write(v)
|
||||
}
|
||||
return buff.Bytes()
|
||||
}
|
||||
|
||||
// Restore 获取收到的信息
|
||||
func (que *StarQueue) Restore(n int) ([]MsgQueue, error) {
|
||||
que.restoreMu.Lock()
|
||||
defer que.restoreMu.Unlock()
|
||||
var res []MsgQueue
|
||||
dura := time.Duration(0)
|
||||
for len(que.MsgPool) < n {
|
||||
select {
|
||||
case <-que.ctx.Done():
|
||||
return res, errors.New("Stoped By External Function Call")
|
||||
default:
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
dura = time.Millisecond*20 + dura
|
||||
if que.duration != 0 && dura > que.duration {
|
||||
return res, errors.New("Time Exceed")
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(que.MsgPool) < n {
|
||||
return res, errors.New("Result Not Enough")
|
||||
}
|
||||
res = que.MsgPool[0:n]
|
||||
que.MsgPool = que.MsgPool[n:]
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// RestoreOne 获取收到的一个信息
|
||||
func (que *StarQueue) RestoreOne() (MsgQueue, error) {
|
||||
data, err := que.Restore(1)
|
||||
if len(data) == 1 {
|
||||
return data[0], err
|
||||
}
|
||||
return MsgQueue{}, err
|
||||
}
|
||||
|
||||
// Stop 立即停止Restore
|
||||
func (que *StarQueue) Stop() {
|
||||
que.cancel()
|
||||
}
|
||||
|
||||
// RestoreDuration Restore最大超时时间
|
||||
func (que *StarQueue) RestoreDuration(tm time.Duration) {
|
||||
que.duration = tm
|
||||
}
|
||||
+401
@@ -0,0 +1,401 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type myConn struct {
|
||||
reader io.Reader
|
||||
conn net.Conn
|
||||
isReadOnly bool
|
||||
multiReader io.Reader
|
||||
}
|
||||
|
||||
func (c *myConn) Read(p []byte) (int, error) {
|
||||
if c.isReadOnly {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
if c.multiReader == nil {
|
||||
c.multiReader = io.MultiReader(c.reader, c.conn)
|
||||
}
|
||||
return c.multiReader.Read(p)
|
||||
}
|
||||
|
||||
func (c *myConn) Write(p []byte) (int, error) {
|
||||
if c.isReadOnly {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return c.conn.Write(p)
|
||||
}
|
||||
func (c *myConn) Close() error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.Close()
|
||||
}
|
||||
func (c *myConn) LocalAddr() net.Addr {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
func (c *myConn) RemoteAddr() net.Addr {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
func (c *myConn) SetDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
func (c *myConn) SetReadDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
func (c *myConn) SetWriteDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
net.Listener
|
||||
cfg *tls.Config
|
||||
getConfigForClient func(hostname string) *tls.Config
|
||||
allowNonTls bool
|
||||
}
|
||||
|
||||
func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config {
|
||||
return l.getConfigForClient
|
||||
}
|
||||
|
||||
func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) {
|
||||
l.getConfigForClient = getConfigForClient
|
||||
}
|
||||
|
||||
func Listen(network, address string) (*Listener, error) {
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{Listener: listener}, nil
|
||||
}
|
||||
|
||||
func ListenTLSWithListenConfig(liscfg net.ListenConfig, network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||
listener, err := liscfg.Listen(context.Background(), network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: config,
|
||||
getConfigForClient: getConfigForClient,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListenWithListener(listener net.Listener, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: config,
|
||||
getConfigForClient: getConfigForClient,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: config,
|
||||
getConfigForClient: getConfigForClient,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) {
|
||||
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{config},
|
||||
}
|
||||
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: tlsConfig,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
tlsCfg: l.cfg,
|
||||
getConfigForClient: l.getConfigForClient,
|
||||
allowNonTls: l.allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
once sync.Once
|
||||
initErr error
|
||||
isTLS bool
|
||||
tlsCfg *tls.Config
|
||||
tlsConn *tls.Conn
|
||||
buffer *bytes.Buffer
|
||||
noTlsReader io.Reader
|
||||
isOriginal bool
|
||||
getConfigForClient func(hostname string) *tls.Config
|
||||
hostname string
|
||||
allowNonTls bool
|
||||
}
|
||||
|
||||
func (c *Conn) Hostname() string {
|
||||
if c.hostname != "" {
|
||||
return c.hostname
|
||||
}
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
if c.tlsConn.ConnectionState().ServerName != "" {
|
||||
c.hostname = c.tlsConn.ConnectionState().ServerName
|
||||
return c.hostname
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Conn) IsTLS() bool {
|
||||
return c.isTLS
|
||||
}
|
||||
|
||||
func (c *Conn) TlsConn() *tls.Conn {
|
||||
return c.tlsConn
|
||||
}
|
||||
|
||||
func (c *Conn) isTLSConnection() (bool, error) {
|
||||
if c.getConfigForClient == nil {
|
||||
peek := make([]byte, 5)
|
||||
n, err := io.ReadFull(c.Conn, peek)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
||||
|
||||
c.buffer = bytes.NewBuffer(peek[:n])
|
||||
return isTLS, nil
|
||||
}
|
||||
|
||||
c.buffer = new(bytes.Buffer)
|
||||
r := io.TeeReader(c.Conn, c.buffer)
|
||||
var hello *tls.ClientHelloInfo
|
||||
tls.Server(&myConn{reader: r, isReadOnly: true}, &tls.Config{
|
||||
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
hello = new(tls.ClientHelloInfo)
|
||||
*hello = *argHello
|
||||
return nil, nil
|
||||
},
|
||||
}).Handshake()
|
||||
peek := c.buffer.Bytes()
|
||||
n := len(peek)
|
||||
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
||||
if hello == nil {
|
||||
return isTLS, nil
|
||||
}
|
||||
c.hostname = hello.ServerName
|
||||
if c.hostname == "" {
|
||||
c.hostname, _, _ = net.SplitHostPort(c.Conn.LocalAddr().String())
|
||||
}
|
||||
return isTLS, nil
|
||||
}
|
||||
|
||||
func (c *Conn) init() {
|
||||
c.once.Do(func() {
|
||||
if c.isOriginal {
|
||||
return
|
||||
}
|
||||
if c.tlsCfg != nil {
|
||||
isTLS, err := c.isTLSConnection()
|
||||
if err != nil {
|
||||
c.initErr = err
|
||||
return
|
||||
}
|
||||
c.isTLS = isTLS
|
||||
}
|
||||
|
||||
if c.isTLS {
|
||||
var cfg = c.tlsCfg
|
||||
if c.getConfigForClient != nil {
|
||||
cfg = c.getConfigForClient(c.hostname)
|
||||
if cfg == nil {
|
||||
cfg = c.tlsCfg
|
||||
}
|
||||
}
|
||||
c.tlsConn = tls.Server(&myConn{
|
||||
reader: c.buffer,
|
||||
conn: c.Conn,
|
||||
isReadOnly: false,
|
||||
}, cfg)
|
||||
} else {
|
||||
if !c.allowNonTls {
|
||||
c.initErr = net.ErrClosed
|
||||
return
|
||||
}
|
||||
c.noTlsReader = io.MultiReader(c.buffer, c.Conn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
c.init()
|
||||
if c.initErr != nil {
|
||||
return 0, c.initErr
|
||||
}
|
||||
if c.isTLS {
|
||||
return c.tlsConn.Read(b)
|
||||
}
|
||||
return c.noTlsReader.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
c.init()
|
||||
if c.initErr != nil {
|
||||
return 0, c.initErr
|
||||
}
|
||||
|
||||
if c.isTLS {
|
||||
return c.tlsConn.Write(b)
|
||||
}
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.Close()
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetDeadline(t)
|
||||
}
|
||||
return c.Conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetReadDeadline(t)
|
||||
}
|
||||
return c.Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetWriteDeadline(t)
|
||||
}
|
||||
return c.Conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) TlsConnection() (*tls.Conn, error) {
|
||||
if c.initErr != nil {
|
||||
return nil, c.initErr
|
||||
}
|
||||
if !c.isTLS {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return c.tlsConn, nil
|
||||
}
|
||||
|
||||
func (c *Conn) OriginalConn() net.Conn {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
c := &Conn{
|
||||
Conn: conn,
|
||||
isTLS: true,
|
||||
tlsCfg: cfg,
|
||||
tlsConn: tls.Client(conn, cfg),
|
||||
isOriginal: true,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
c := &Conn{
|
||||
Conn: conn,
|
||||
isTLS: true,
|
||||
tlsCfg: cfg,
|
||||
tlsConn: tls.Server(conn, cfg),
|
||||
isOriginal: true,
|
||||
}
|
||||
c.init()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func Dial(network, address string) (*Conn, error) {
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
isTLS: false,
|
||||
tlsCfg: nil,
|
||||
tlsConn: nil,
|
||||
noTlsReader: conn,
|
||||
isOriginal: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) {
|
||||
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{config},
|
||||
}
|
||||
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewClientTlsConn(conn, tlsConfig)
|
||||
}
|
||||
Reference in New Issue
Block a user