feat: 完善 staros 系统能力并更新 wincmd 发布版依赖

- 重构 sysconf 为文档模型 INI Parser 与 Config Framework
- 强化 hosts 解析、插入校验、写回与异常输入处理
- 完善 StarCmd 生命周期、等待 API、流式输出与 IO 重定向
- 扩展跨平台文件时间、文件锁、内存、进程与网络能力
- 将 Windows 进程适配更新到 b612.me/wincmd v0.1.0
- 移除本地 wincmd/win32api replace,改用发布版依赖
- 将最低 Go 版本提升到 1.18
- 补充 hosts、sysconf、FileLock、StarCmd 与平台适配回归测试
This commit is contained in:
兔子 2026-06-09 18:10:19 +08:00
parent 0b6373e9e3
commit d93a851d1b
Signed by: b612
GPG Key ID: 99DD2222B612B612
44 changed files with 9774 additions and 2260 deletions

119
README.md Normal file
View File

@ -0,0 +1,119 @@
# staros
`staros` is a cgo-free Go package for small cross-platform OS utilities.
The package keeps compatibility with existing APIs, but platform-dependent functions should prefer explicit error-returning variants where available. Unsupported platform implementations return `ERR_UNSUPPORTED` instead of silently pretending to work.
## Go Version
`go.mod` declares `go 1.18`. This release targets Go 1.18 or newer and no longer promises Go 1.16/1.17 compatibility. The release gate includes the current local Go toolchain plus cross-platform compile checks.
## Platform Support
| Area | Linux | Windows | Darwin |
| --- | --- | --- | --- |
| Basic path checks: `Exists`, `IsFile`, `IsFolder` | Supported | Supported | Supported |
| File locks: `FileLock` | `flock` | Win32 lock API | `flock` |
| File timestamps: `GetFileCreationTime`, `GetFileAccessTime`, `SetFileTimesE` | ctime/access/modtime via `x/sys/unix` | creation/access/modtime via `x/sys/windows` | birth/access/modtime via stdlib/syscall |
| Memory: `Memory` | `/proc` and `syscall.Sysinfo` | Win32 `GlobalMemoryStatusEx` | `vm_stat` and `sysctl` |
| Disk: `DiskUsageE` | `Statfs` | `GetDiskFreeSpaceExW` | `Statfs` |
| OS identity: `IsRoot`, `Whoami` | Supported | `IsRoot` supported, `Whoami` returns `ERR_UNSUPPORTED` | Supported |
| CPU usage: `CpuUsage`, `CpuUsageByPid` | `/proc/stat` and `/proc/<pid>/stat` | Stub returns `0` | Stub returns `0` |
| Process query: `FindProcess*` | `/proc` backed | `wincmd` backed basic fields | Returns `ERR_UNSUPPORTED` |
| Process launch: `Command`, `CommandContext`, `Start` | Supported | Supported | Supported |
| Process lifecycle: `ReleaseE` | Starts through normal lifecycle with `Setsid` | Starts through normal lifecycle | Returns `ERR_UNSUPPORTED` |
| Process detach: `DetachE` | Start-before-Start only, then `Process.Release` | Start-before-Start only, then `Process.Release` | Returns `ERR_UNSUPPORTED` |
| Run as user: `SetRunUserE`, `DaemonWithUser` | Supported | Returns `ERR_UNSUPPORTED` | Returns `ERR_UNSUPPORTED` |
| Keep capabilities: `SetKeepCaps`, `StarCmd.SetKeepCaps` | Package helper uses Linux `prctl`; command helper preserves current caps via `AmbientCaps` | Returns `ERR_UNSUPPORTED` | Returns `ERR_UNSUPPORTED` |
| Network adapters/speeds/connections | `/proc/net` backed | IP Helper adapter counters and TCP/UDP owner-pid tables; no Unix socket/inode fields | Returns `ERR_UNSUPPORTED` |
| Beep | PC speaker or terminal bell fallback | Win32 `Beep` | `osascript` or terminal bell fallback |
## API Semantics
- `ERR_UNSUPPORTED` means the symbol exists for build compatibility but the current OS implementation is intentionally unavailable.
- `ReleaseE` preserves `StarCmd` lifecycle observation: `Stopped()` still closes after the process exits and `ExitCode()` is populated when available. The historical misspelled `Stoped()` method remains as a deprecated compatibility alias.
- `DetachE` is a true detached start path and must be called before `Start`; calling it after `Start` returns an already-started error to avoid racing `Process.Release()` with the internal `Wait()`.
- `Wait`, `WaitContext`, and `WaitTimeout` provide explicit lifecycle wait helpers that return the final process wait error. `Stopped()` remains available when callers only need a close signal.
- On Linux, the package-level `SetKeepCaps()` helper applies `prctl(PR_SET_KEEPCAPS)` to the current process. `StarCmd.SetKeepCaps()` is different: it snapshots the current capability set and configures the child command's `SysProcAttr.AmbientCaps` before `Start`.
- `StdoutChan`, `StderrChan`, and `OutputChan` provide best-effort streaming observation for future output chunks. They close with `Stopped()` and do not replace the existing full-output capture methods.
- `RedirectStdout`, `RedirectStderr`, `RedirectOutput`, and `RedirectStdin` configure process IO before `Start`. File helpers such as `RedirectStdoutFile` and `RedirectStdinFile` open the file and close it after the process reaches its final state.
- `WriteStdinE` and `WriteStdinStringE` write raw stdin data without appending a newline. `WriteStdinLineE` and the legacy `WriteCmdE` append one newline.
- Legacy methods that do not return errors are kept for compatibility; prefer the `*E` variants for new code.
## sysconf Migration
The `sysconf` package now uses a document-backed INI model instead of the old `SysConf` struct shape. This is a breaking API change: callers should migrate to `NewIni`, `NewLinuxConf`, `Document`, `Section`, and `Entry` directly instead of relying on a field-compatible wrapper.
Minimal migration rules:
- Replace parser setup through `SysConf` fields with constructors: use `sysconf.NewIni()` for sectioned INI files, or `sysconf.NewLinuxConf(equal)` for flat Linux-style config files.
- Replace direct segment/key data mutation with section methods: `ini.Section(name)`, `ini.Set(section, key, value)`, `sec.Set`, `sec.SetAll`, `sec.AddValue`, `sec.Delete`, and `ini.DeleteSection`.
- Replace single-value assumptions with duplicate-aware reads where needed: `ini.Get`/`sec.Get` returns the first value, while `ini.GetAll`/`sec.GetAll` returns all repeated keys.
- Replace manual struct binding with `ini.Unmarshal(&dst)` and `ini.Marshal(src)` using `seg` and `key` tags.
- Use `ini.Build()` or `ini.Save(path)` for write-back. Unchanged parsed lines keep their original formatting, while changed entries are rebuilt from the new model.
Example:
```go
ini := sysconf.NewIni()
if err := ini.Parse(data); err != nil {
return err
}
app := ini.Section("app")
if app == nil {
app = ini.AddSection("app")
}
_ = app.SetInt("port", 9090, "")
_ = app.SetAll("feature", []string{"stable", "audit"}, "")
out := ini.Build()
```
INI parser capabilities:
- `NewIni()` parses common sectioned INI files with `=` and `:` key/value delimiters, `#` and `;` comments, section header comments, quoted values, no-value keys, duplicate keys, duplicate sections, and backslash line continuation.
- Inline comments require whitespace before the comment marker, so values such as URLs or fragments containing `#` are not truncated accidentally.
- `NewIniWithProfiles(...)`, `StrictINIProfile()`, and `LinuxConfProfile(equal)` let callers pin parser behavior for strict sectioned INI or flat Linux-style config files without mutating parser fields ad hoc.
- `Document.Strict` can be enabled when callers want malformed input to return a `ParseError` with line and column information instead of preserving unknown lines as raw content.
- Write-back is lossless for unchanged parsed lines. Changed values that would otherwise be misread as comments, leading/trailing whitespace, tabs, or newlines are emitted as quoted values.
Config framework capabilities:
- `sysconf.NewConfig()`, `sysconf.LoadConfig(&dst, files, ...)`, and `sysconf.LoadConfigSources(&dst, sources, ...)` load one or more INI sources in order; later sources override earlier values for the same section/key.
- `sysconf.RequiredFile(path)` and `sysconf.OptionalFile(path)` declare whether a missing file should fail loading or be skipped. `sysconf.BytesSource(name, data)` and `sysconf.StringSource(name, data)` support in-memory overlays for tests, generated defaults, and embedded configs.
- `Config` exposes direct access and write-back helpers: `Get`, `GetAll`, error-returning typed getters such as `GetIntE` / `GetBoolE` / `GetDurationE`, `Has`, `Set`, `SetAll`, `Delete`, `Build`, `Save`, and `SaveAtomic`.
- Struct binding uses `seg`, `key`, `default`, `env`, `split`, and `required` tags. Nested structs inherit their parent `seg` tag, while `env:"-"` disables environment overrides for a field.
- Environment overrides are opt-in through `WithEnvPrefix` / `WithEnvLookup`; generated names normalize section and key names to uppercase underscore form, such as `APP_SERVER_PORT`.
- Binding supports strings, bools, signed/unsigned integers, floats, `time.Duration`, `encoding.TextUnmarshaler`, slices, arrays, and `map[string]T` for scalar `T`.
- Repeated INI keys bind naturally to slices. When a single value should expand into multiple collection items, add an explicit `split` tag such as `split:","`, `split:"|"`, or `split:"csv"`.
- `Config.SetStruct(src)` writes a config struct back into the current document, using repeated keys for slices/arrays and sorted `key=value` repeated entries for maps.
- `sysconf.DescribeConfig(src)` exports struct tag metadata as `ConfigFieldInfo` records, and `sysconf.SampleConfig(src)` builds a sample INI from defaults, current struct values, required placeholders, or type zero values without mutating the source struct.
- `Config.SectionNames()`, `Config.Keys(section)`, `Config.Flatten()`, and `Config.FlattenEntries()` expose sorted section/key discovery, duplicate-aware flattened values, and structured section/key/value entries for diagnostics, tests, and lightweight config export.
- `ConfigError` and `ConfigSourceError` include structured metadata and unwrap their underlying parse, file, or conversion errors for `errors.Is` / `errors.As`.
- A config struct can implement `Validate() error`; validation runs after defaults, file values, env overrides, required checks, and type binding.
Example:
```go
type AppConfig struct {
App struct {
Name string `key:"name" required:"true"`
Port int `key:"port" default:"8080"`
Timeout time.Duration `key:"timeout" default:"5s"`
Tags []string `key:"tag" env:"APP_TAGS"`
} `seg:"app"`
}
var cfg AppConfig
_, err := sysconf.LoadConfigSources(&cfg, []sysconf.ConfigSource{
sysconf.RequiredFile("/etc/app.ini"),
sysconf.OptionalFile("/etc/app.local.ini"),
}, sysconf.WithEnvPrefix("APP"))
```
The current framework is intentionally local-file focused. It does not include hot reload, remote configuration centers, secret managers, or a schema DSL.
## Scope
This package is intentionally small and cgo-free. Functionality that duplicates broad system inventory packages should stay frozen unless it improves cross-platform semantics, error observability, or compatibility for existing callers.

View File

@ -7,6 +7,12 @@ import (
var ERR_ALREADY_LOCKED = errors.New("ALREADY LOCKED") var ERR_ALREADY_LOCKED = errors.New("ALREADY LOCKED")
var ERR_TIMEOUT = errors.New("TIME OUT") var ERR_TIMEOUT = errors.New("TIME OUT")
var ERR_UNSUPPORTED = errors.New("UNSUPPORTED")
var errNilFile = errors.New("nil file")
var errNilFileInfo = errors.New("nil file info")
var errUnsupportedFileInfo = errors.New("unsupported file info")
var errFileLockNotLocked = errors.New("file lock is not locked")
func NewFileLock(filepath string) FileLock { func NewFileLock(filepath string) FileLock {
return FileLock{ return FileLock{
@ -24,7 +30,7 @@ func Exists(path string) bool {
} }
// IsFile 返回给定文件地址是否是一个文件, // IsFile 返回给定文件地址是否是一个文件,
//True为是一个文件,False为不是文件或路径无效 // True为是一个文件,False为不是文件或路径无效
func IsFile(fpath string) bool { func IsFile(fpath string) bool {
s, err := os.Stat(fpath) s, err := os.Stat(fpath)
if err != nil { if err != nil {
@ -34,7 +40,7 @@ func IsFile(fpath string) bool {
} }
// IsFolder 返回给定文件地址是否是一个文件夹, // IsFolder 返回给定文件地址是否是一个文件夹,
//True为是一个文件夹,False为不是文件夹或路径无效 // True为是一个文件夹,False为不是文件夹或路径无效
func IsFolder(fpath string) bool { func IsFolder(fpath string) bool {
s, err := os.Stat(fpath) s, err := os.Stat(fpath)
if err != nil { if err != nil {

View File

@ -1,9 +1,9 @@
//+build darwin //go:build darwin
// +build darwin
package staros package staros
import ( import (
"b612.me/stario"
"os" "os"
"syscall" "syscall"
"time" "time"
@ -12,6 +12,7 @@ import (
type FileLock struct { type FileLock struct {
fd int fd int
filepath string filepath string
locked bool
} }
func (f *FileLock) openFileForLock() error { func (f *FileLock) openFileForLock() error {
@ -19,7 +20,6 @@ func (f *FileLock) openFileForLock() error {
if err != nil { if err != nil {
return err return err
} }
f.filepath = f.filepath
f.fd = fd f.fd = fd
return nil return nil
} }
@ -31,10 +31,7 @@ func (f *FileLock) Lock(Exclusive bool) error {
} else { } else {
lockType = syscall.LOCK_SH lockType = syscall.LOCK_SH
} }
if err := f.openFileForLock(); err != nil { return f.lockWithFlags(lockType)
return err
}
return syscall.Flock(f.fd, lockType)
} }
func (f *FileLock) LockNoBlocking(Exclusive bool) error { func (f *FileLock) LockNoBlocking(Exclusive bool) error {
@ -44,38 +41,78 @@ func (f *FileLock) LockNoBlocking(Exclusive bool) error {
} else { } else {
lockType = syscall.LOCK_SH lockType = syscall.LOCK_SH
} }
return f.lockWithFlags(lockType | syscall.LOCK_NB)
}
func (f *FileLock) lockWithFlags(lockType int) error {
if f.locked {
return ERR_ALREADY_LOCKED
}
if err := f.openFileForLock(); err != nil { if err := f.openFileForLock(); err != nil {
return err return err
} }
err := syscall.Flock(f.fd, lockType|syscall.LOCK_NB) err := syscall.Flock(f.fd, lockType)
if err != nil { if err != nil {
syscall.Close(f.fd) _ = syscall.Close(f.fd)
f.fd = 0
if err == syscall.EWOULDBLOCK { if err == syscall.EWOULDBLOCK {
return ERR_ALREADY_LOCKED return ERR_ALREADY_LOCKED
} }
}
return err return err
}
f.locked = true
return nil
} }
func (f *FileLock) Unlock() error { func (f *FileLock) Unlock() error {
if f == nil || !f.locked {
return errFileLockNotLocked
}
err := syscall.Flock(f.fd, syscall.LOCK_UN) err := syscall.Flock(f.fd, syscall.LOCK_UN)
if err != nil { if err != nil {
return err return err
} }
return syscall.Close(f.fd) if err := syscall.Close(f.fd); err != nil {
return err
}
f.fd = 0
f.locked = false
return nil
} }
func (f *FileLock) LockWithTimeout(tm time.Duration, Exclusive bool) error { func (f *FileLock) LockWithTimeout(tm time.Duration, Exclusive bool) error {
return stario.WaitUntilTimeout(tm, func(tmout chan struct{}) error { if f.locked {
err := f.Lock(Exclusive) return ERR_ALREADY_LOCKED
select {
case <-tmout:
f.Unlock()
return nil
default:
} }
var lockType int
if Exclusive {
lockType = syscall.LOCK_EX
} else {
lockType = syscall.LOCK_SH
}
if tm < 0 {
return f.Lock(Exclusive)
}
deadline := time.Now().Add(tm)
for {
err := f.lockWithFlags(lockType | syscall.LOCK_NB)
if err == nil {
return nil
}
if err != ERR_ALREADY_LOCKED {
return err return err
}) }
if !time.Now().Before(deadline) {
return ERR_TIMEOUT
}
sleep := time.Millisecond * 10
if remaining := time.Until(deadline); remaining < sleep {
sleep = remaining
}
if sleep > 0 {
time.Sleep(sleep)
}
}
} }
func timespecToTime(ts syscall.Timespec) time.Time { func timespecToTime(ts syscall.Timespec) time.Time {
@ -83,9 +120,57 @@ func timespecToTime(ts syscall.Timespec) time.Time {
} }
func GetFileCreationTime(fileinfo os.FileInfo) time.Time { func GetFileCreationTime(fileinfo os.FileInfo) time.Time {
return timespecToTime(fileinfo.Sys().(*syscall.Stat_t).Ctimespec) if fileinfo == nil {
return time.Time{}
}
if stat, ok := fileinfo.Sys().(*syscall.Stat_t); ok && stat != nil {
return timespecToTime(stat.Birthtimespec)
}
return time.Time{}
} }
func GetFileAccessTime(fileinfo os.FileInfo) time.Time { func GetFileAccessTime(fileinfo os.FileInfo) time.Time {
return timespecToTime(fileinfo.Sys().(*syscall.Stat_t).Atimespec) if fileinfo == nil {
return time.Time{}
}
if stat, ok := fileinfo.Sys().(*syscall.Stat_t); ok && stat != nil {
return timespecToTime(stat.Atimespec)
}
return time.Time{}
}
func SetFileTimes(file *os.File, info os.FileInfo) {
_ = SetFileTimesE(file, info)
}
func SetFileTimesbyTime(file *os.File) {
_ = SetFileTimesbyTimeE(file)
}
func SetFileTimesE(file *os.File, info os.FileInfo) error {
if file == nil {
return errNilFile
}
if info == nil {
return errNilFileInfo
}
stat, ok := info.Sys().(*syscall.Stat_t)
if !ok || stat == nil {
return errUnsupportedFileInfo
}
atime := timespecToTime(stat.Atimespec)
mtime := info.ModTime()
return os.Chtimes(file.Name(), atime, mtime)
}
func SetFileTimesByTimeE(file *os.File) error {
return SetFileTimesbyTimeE(file)
}
func SetFileTimesbyTimeE(file *os.File) error {
if file == nil {
return errNilFile
}
now := time.Now()
return os.Chtimes(file.Name(), now, now)
} }

View File

@ -1,23 +1,139 @@
package staros package staros
import ( import (
"fmt" "errors"
"os" "os"
"path/filepath"
"runtime"
"testing" "testing"
"time" "time"
) )
func Test_FileLock(t *testing.T) { func Test_FileLock(t *testing.T) {
filename := "./test.file" filename := filepath.Join(t.TempDir(), "test.file")
lock := NewFileLock(filename) lock := NewFileLock(filename)
lock2 := NewFileLock(filename) lock2 := NewFileLock(filename)
fmt.Println("lock1", lock.LockNoBlocking(false)) if err := lock.LockNoBlocking(false); err != nil {
time.Sleep(time.Second) t.Fatal(err)
fmt.Println("lock2", lock2.LockWithTimeout(time.Second*5, false)) }
fmt.Println("unlock1", lock.Unlock()) if err := lock2.LockNoBlocking(false); err != nil {
time.Sleep(time.Second) t.Fatal(err)
fmt.Println("unlock2", lock2.Unlock()) }
fmt.Println("lock2", lock2.LockNoBlocking(true)) if err := lock.Unlock(); err != nil {
fmt.Println("unlock2", lock2.Unlock()) t.Fatal(err)
os.Remove(filename) }
if err := lock2.Unlock(); err != nil {
t.Fatal(err)
}
if err := lock2.LockNoBlocking(true); err != nil {
t.Fatal(err)
}
if err := lock2.Unlock(); err != nil {
t.Fatal(err)
}
_ = os.Remove(filename)
}
func TestFileLockExclusiveConflictTimeout(t *testing.T) {
filename := filepath.Join(t.TempDir(), "timeout.file")
lock := NewFileLock(filename)
contender := NewFileLock(filename)
if err := lock.Lock(true); err != nil {
t.Fatal(err)
}
defer lock.Unlock()
if err := contender.LockNoBlocking(true); !errors.Is(err, ERR_ALREADY_LOCKED) {
if err == nil {
_ = contender.Unlock()
}
t.Fatalf("expected non-blocking exclusive lock conflict, got %v", err)
}
start := time.Now()
if err := contender.LockWithTimeout(50*time.Millisecond, true); !errors.Is(err, ERR_TIMEOUT) {
if err == nil {
_ = contender.Unlock()
}
t.Fatalf("expected exclusive lock timeout, got %v", err)
}
if elapsed := time.Since(start); elapsed > time.Second {
t.Fatalf("lock timeout took too long: %s", elapsed)
}
if err := lock.Unlock(); err != nil {
t.Fatal(err)
}
if err := contender.LockWithTimeout(time.Second, true); err != nil {
t.Fatalf("expected lock after owner unlock, got %v", err)
}
if err := contender.Unlock(); err != nil {
t.Fatal(err)
}
}
func TestFileLockUnlockWithoutSuccessfulLock(t *testing.T) {
filename := filepath.Join(t.TempDir(), "unlock-state.file")
lock := NewFileLock(filename)
if err := lock.Unlock(); !errors.Is(err, errFileLockNotLocked) {
t.Fatalf("expected unlock without lock error, got %v", err)
}
owner := NewFileLock(filename)
contender := NewFileLock(filename)
if err := owner.Lock(true); err != nil {
t.Fatal(err)
}
defer owner.Unlock()
if err := contender.LockNoBlocking(true); !errors.Is(err, ERR_ALREADY_LOCKED) {
if err == nil {
_ = contender.Unlock()
}
t.Fatalf("expected lock conflict, got %v", err)
}
if err := contender.Unlock(); !errors.Is(err, errFileLockNotLocked) {
t.Fatalf("failed lock attempt should not be unlockable, got %v", err)
}
if err := owner.Unlock(); err != nil {
t.Fatal(err)
}
}
func TestFileLockRejectsSecondLockOnSameObject(t *testing.T) {
filename := filepath.Join(t.TempDir(), "same-object.file")
lock := NewFileLock(filename)
if err := lock.Lock(true); err != nil {
t.Fatal(err)
}
defer lock.Unlock()
for name, fn := range map[string]func() error{
"Lock": func() error { return lock.Lock(true) },
"LockNoBlocking": func() error { return lock.LockNoBlocking(true) },
"LockWithTimeout": func() error { return lock.LockWithTimeout(time.Second, true) },
} {
if err := fn(); !errors.Is(err, ERR_ALREADY_LOCKED) {
if err == nil {
_ = lock.Unlock()
}
t.Fatalf("expected %s on same lock object to reject second lock, got %v", name, err)
}
}
}
func TestGetFileCreationTimeLinuxUnavailable(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("linux-only creation time fallback")
}
filename := filepath.Join(t.TempDir(), "creation-time.file")
if err := os.WriteFile(filename, []byte("demo"), 0o644); err != nil {
t.Fatal(err)
}
info, err := os.Stat(filename)
if err != nil {
t.Fatal(err)
}
if got := GetFileCreationTime(info); !got.IsZero() {
t.Fatalf("linux FileInfo should not report synthetic creation time, got %s", got)
}
} }

View File

@ -1,9 +1,10 @@
//+build linux //go:build linux
// +build linux
package staros package staros
import ( import (
"b612.me/stario" "golang.org/x/sys/unix"
"os" "os"
"syscall" "syscall"
"time" "time"
@ -12,6 +13,7 @@ import (
type FileLock struct { type FileLock struct {
fd int fd int
filepath string filepath string
locked bool
} }
func timespecToTime(ts syscall.Timespec) time.Time { func timespecToTime(ts syscall.Timespec) time.Time {
@ -19,11 +21,66 @@ func timespecToTime(ts syscall.Timespec) time.Time {
} }
func GetFileCreationTime(fileinfo os.FileInfo) time.Time { func GetFileCreationTime(fileinfo os.FileInfo) time.Time {
return timespecToTime(fileinfo.Sys().(*syscall.Stat_t).Ctim) if fileinfo == nil {
return time.Time{}
}
// Linux os.FileInfo/syscall.Stat_t does not expose a stable birth time.
// Returning ctime here would be wrong because it tracks inode changes.
return time.Time{}
} }
func GetFileAccessTime(fileinfo os.FileInfo) time.Time { func GetFileAccessTime(fileinfo os.FileInfo) time.Time {
return timespecToTime(fileinfo.Sys().(*syscall.Stat_t).Atim) if fileinfo == nil {
return time.Time{}
}
if stat, ok := fileinfo.Sys().(*syscall.Stat_t); ok && stat != nil {
return timespecToTime(stat.Atim)
}
return time.Time{}
}
func SetFileTimes(file *os.File, info os.FileInfo) {
_ = SetFileTimesE(file, info)
}
func SetFileTimesbyTime(file *os.File) {
_ = SetFileTimesbyTimeE(file)
}
func SetFileTimesE(file *os.File, info os.FileInfo) error {
if file == nil {
return errNilFile
}
if info == nil {
return errNilFileInfo
}
stat, ok := info.Sys().(*syscall.Stat_t)
if !ok || stat == nil {
return errUnsupportedFileInfo
}
atime := timespecToTime(stat.Atim)
mtime := info.ModTime()
return setFileTimes(file.Name(), atime, mtime)
}
func SetFileTimesByTimeE(file *os.File) error {
return SetFileTimesbyTimeE(file)
}
func SetFileTimesbyTimeE(file *os.File) error {
if file == nil {
return errNilFile
}
now := time.Now()
return setFileTimes(file.Name(), now, now)
}
func setFileTimes(path string, atime, mtime time.Time) error {
ts := [2]unix.Timespec{
unix.NsecToTimespec(atime.UnixNano()),
unix.NsecToTimespec(mtime.UnixNano()),
}
return unix.UtimesNanoAt(unix.AT_FDCWD, path, ts[:], unix.AT_SYMLINK_NOFOLLOW)
} }
func (f *FileLock) openFileForLock() error { func (f *FileLock) openFileForLock() error {
@ -31,7 +88,6 @@ func (f *FileLock) openFileForLock() error {
if err != nil { if err != nil {
return err return err
} }
f.filepath = f.filepath
f.fd = fd f.fd = fd
return nil return nil
} }
@ -43,10 +99,7 @@ func (f *FileLock) Lock(Exclusive bool) error {
} else { } else {
lockType = syscall.LOCK_SH lockType = syscall.LOCK_SH
} }
if err := f.openFileForLock(); err != nil { return f.lockWithFlags(lockType)
return err
}
return syscall.Flock(f.fd, lockType)
} }
func (f *FileLock) LockNoBlocking(Exclusive bool) error { func (f *FileLock) LockNoBlocking(Exclusive bool) error {
@ -56,36 +109,76 @@ func (f *FileLock) LockNoBlocking(Exclusive bool) error {
} else { } else {
lockType = syscall.LOCK_SH lockType = syscall.LOCK_SH
} }
return f.lockWithFlags(lockType | syscall.LOCK_NB)
}
func (f *FileLock) lockWithFlags(lockType int) error {
if f.locked {
return ERR_ALREADY_LOCKED
}
if err := f.openFileForLock(); err != nil { if err := f.openFileForLock(); err != nil {
return err return err
} }
err := syscall.Flock(f.fd, lockType|syscall.LOCK_NB) err := syscall.Flock(f.fd, lockType)
if err != nil { if err != nil {
syscall.Close(f.fd) _ = syscall.Close(f.fd)
f.fd = 0
if err == syscall.EWOULDBLOCK { if err == syscall.EWOULDBLOCK {
return ERR_ALREADY_LOCKED return ERR_ALREADY_LOCKED
} }
}
return err return err
}
f.locked = true
return nil
} }
func (f *FileLock) Unlock() error { func (f *FileLock) Unlock() error {
if f == nil || !f.locked {
return errFileLockNotLocked
}
err := syscall.Flock(f.fd, syscall.LOCK_UN) err := syscall.Flock(f.fd, syscall.LOCK_UN)
if err != nil { if err != nil {
return err return err
} }
return syscall.Close(f.fd) if err := syscall.Close(f.fd); err != nil {
return err
}
f.fd = 0
f.locked = false
return nil
} }
func (f *FileLock) LockWithTimeout(tm time.Duration, Exclusive bool) error { func (f *FileLock) LockWithTimeout(tm time.Duration, Exclusive bool) error {
return stario.WaitUntilTimeout(tm, func(tmout chan struct{}) error { if f.locked {
err := f.Lock(Exclusive) return ERR_ALREADY_LOCKED
select {
case <-tmout:
f.Unlock()
return nil
default:
} }
var lockType int
if Exclusive {
lockType = syscall.LOCK_EX
} else {
lockType = syscall.LOCK_SH
}
if tm < 0 {
return f.Lock(Exclusive)
}
deadline := time.Now().Add(tm)
for {
err := f.lockWithFlags(lockType | syscall.LOCK_NB)
if err == nil {
return nil
}
if err != ERR_ALREADY_LOCKED {
return err return err
}) }
if !time.Now().Before(deadline) {
return ERR_TIMEOUT
}
sleep := time.Millisecond * 10
if remaining := time.Until(deadline); remaining < sleep {
sleep = remaining
}
if sleep > 0 {
time.Sleep(sleep)
}
}
} }

View File

@ -1,9 +1,11 @@
//go:build windows
// +build windows // +build windows
package staros package staros
import ( import (
"b612.me/win32api" "b612.me/win32api"
"golang.org/x/sys/windows"
"os" "os"
"syscall" "syscall"
"time" "time"
@ -12,23 +14,89 @@ import (
type FileLock struct { type FileLock struct {
filepath string filepath string
handle win32api.HANDLE handle win32api.HANDLE
locked bool
} }
func GetFileCreationTime(fileinfo os.FileInfo) time.Time { func GetFileCreationTime(fileinfo os.FileInfo) time.Time {
d := fileinfo.Sys().(*syscall.Win32FileAttributeData) if fileinfo == nil {
return time.Unix(0, d.CreationTime.Nanoseconds()) return time.Time{}
}
if data, ok := fileinfo.Sys().(*syscall.Win32FileAttributeData); ok && data != nil {
return time.Unix(0, data.CreationTime.Nanoseconds())
}
return time.Time{}
} }
func GetFileAccessTime(fileinfo os.FileInfo) time.Time { func GetFileAccessTime(fileinfo os.FileInfo) time.Time {
d := fileinfo.Sys().(*syscall.Win32FileAttributeData) if fileinfo == nil {
return time.Unix(0, d.LastAccessTime.Nanoseconds()) return time.Time{}
}
if data, ok := fileinfo.Sys().(*syscall.Win32FileAttributeData); ok && data != nil {
return time.Unix(0, data.LastAccessTime.Nanoseconds())
}
return time.Time{}
} }
func SetFileTimes(file *os.File, info os.FileInfo) { func SetFileTimes(file *os.File, info os.FileInfo) {
_ = SetFileTimesE(file, info)
} }
func SetFileTimesbyTime(file *os.File) { func SetFileTimesbyTime(file *os.File) {
_ = SetFileTimesbyTimeE(file)
}
func SetFileTimesE(file *os.File, info os.FileInfo) error {
if file == nil {
return errNilFile
}
if info == nil {
return errNilFileInfo
}
data, ok := info.Sys().(*syscall.Win32FileAttributeData)
if !ok || data == nil {
return errUnsupportedFileInfo
}
ctime := time.Unix(0, data.CreationTime.Nanoseconds())
atime := time.Unix(0, data.LastAccessTime.Nanoseconds())
mtime := info.ModTime()
return setFileTimes(file.Name(), ctime, atime, mtime)
}
func SetFileTimesByTimeE(file *os.File) error {
return SetFileTimesbyTimeE(file)
}
func SetFileTimesbyTimeE(file *os.File) error {
if file == nil {
return errNilFile
}
now := time.Now()
return setFileTimes(file.Name(), now, now, now)
}
func setFileTimes(path string, ctime, atime, mtime time.Time) error {
path16, err := windows.UTF16PtrFromString(path)
if err != nil {
return err
}
handle, err := windows.CreateFile(
path16,
windows.FILE_WRITE_ATTRIBUTES,
windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE,
nil,
windows.OPEN_EXISTING,
windows.FILE_FLAG_BACKUP_SEMANTICS,
0,
)
if err != nil {
return err
}
defer windows.CloseHandle(handle)
ctimeFt := windows.NsecToFiletime(ctime.UnixNano())
atimeFt := windows.NsecToFiletime(atime.UnixNano())
mtimeFt := windows.NsecToFiletime(mtime.UnixNano())
return windows.SetFileTime(handle, &ctimeFt, &atimeFt, &mtimeFt)
} }
@ -53,21 +121,27 @@ func (f *FileLock) openFileForLock() error {
} }
func (f *FileLock) lockForTimeout(timeout time.Duration, lockType win32api.DWORD) error { func (f *FileLock) lockForTimeout(timeout time.Duration, lockType win32api.DWORD) error {
if f.locked {
return ERR_ALREADY_LOCKED
}
var err error var err error
if err = f.openFileForLock(); err != nil { if err = f.openFileForLock(); err != nil {
return err return err
} }
event, err := win32api.CreateEventW(nil, true, false, nil) event, err := win32api.CreateEventW(nil, true, false, nil)
if err != nil { if err != nil {
_ = f.closeHandle()
return err return err
} }
myEvent := &syscall.Overlapped{HEvent: syscall.Handle(event)} myEvent := &syscall.Overlapped{HEvent: syscall.Handle(event)}
defer syscall.CloseHandle(myEvent.HEvent) defer syscall.CloseHandle(myEvent.HEvent)
_, err = win32api.LockFileEx(f.handle, lockType, 0, 1, 0, myEvent) _, err = win32api.LockFileEx(f.handle, lockType, 0, 1, 0, myEvent)
if err == nil { if err == nil {
f.locked = true
return nil return nil
} }
if err != syscall.ERROR_IO_PENDING { if err != syscall.ERROR_IO_PENDING {
_ = f.closeHandle()
return err return err
} }
millis := uint32(syscall.INFINITE) millis := uint32(syscall.INFINITE)
@ -78,12 +152,13 @@ func (f *FileLock) lockForTimeout(timeout time.Duration, lockType win32api.DWORD
switch s { switch s {
case syscall.WAIT_OBJECT_0: case syscall.WAIT_OBJECT_0:
// success! // success!
f.locked = true
return nil return nil
case syscall.WAIT_TIMEOUT: case syscall.WAIT_TIMEOUT:
f.Unlock() _ = f.closeHandle()
return ERR_TIMEOUT return ERR_TIMEOUT
default: default:
f.Unlock() _ = f.closeHandle()
return err return err
} }
} }
@ -95,7 +170,7 @@ func (f *FileLock) Lock(Exclusive bool) error {
} else { } else {
lockType = 0 lockType = 0
} }
return f.lockForTimeout(0, lockType) return f.lockForTimeout(-1, lockType)
} }
func (f *FileLock) LockWithTimeout(tm time.Duration, Exclusive bool) error { func (f *FileLock) LockWithTimeout(tm time.Duration, Exclusive bool) error {
@ -119,5 +194,24 @@ func (f *FileLock) LockNoBlocking(Exclusive bool) error {
} }
func (f *FileLock) Unlock() error { func (f *FileLock) Unlock() error {
return syscall.Close(syscall.Handle(f.handle)) if f == nil || !f.locked {
return errFileLockNotLocked
}
if err := f.closeHandle(); err != nil {
return err
}
f.locked = false
return nil
}
func (f *FileLock) closeHandle() error {
if f == nil || f.handle == 0 {
return nil
}
err := syscall.Close(syscall.Handle(f.handle))
if err != nil {
return err
}
f.handle = 0
return nil
} }

13
go.mod
View File

@ -1,10 +1,15 @@
module b612.me/staros module b612.me/staros
go 1.16 go 1.18
require ( require (
b612.me/stario v0.0.10 b612.me/win32api v0.0.4
b612.me/win32api v0.0.2 b612.me/wincmd v0.1.0
b612.me/wincmd v0.0.4
golang.org/x/sys v0.24.0 golang.org/x/sys v0.24.0
) )
require (
b612.me/stario v0.0.11 // indirect
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/term v0.23.0 // indirect
)

12
go.sum
View File

@ -1,9 +1,9 @@
b612.me/stario v0.0.10 h1:+cIyiDCBCjUfodMJDp4FLs+2E1jo7YENkN+sMEe6550= b612.me/stario v0.0.11 h1:H5SN5G36ZlW7Lu5co3CWK59eHVJduqHSa9a29Cx5ExQ=
b612.me/stario v0.0.10/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk= b612.me/stario v0.0.11/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk=
b612.me/win32api v0.0.2 h1:5PwvPR5fYs3a/v+LjYdtRif+5Q04zRGLTVxmCYNjCpA= b612.me/win32api v0.0.4 h1:V3LgCTbl8UF0Tb1UJDXl8+F/404yLA0XtC/131KmQ7c=
b612.me/win32api v0.0.2/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0= b612.me/win32api v0.0.4/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0=
b612.me/wincmd v0.0.4 h1:fv9p1V8mw2HdUjaoZBWZy0T41JftueyLxAuch1MgtdI= b612.me/wincmd v0.1.0 h1:hLOvoIvsPhesb7XbN0l+pcfu4YNWog7YYw11MAkOiDs=
b612.me/wincmd v0.0.4/go.mod h1:o3yPoE+DpVPHGKl/q1WT1C8OaIVwHEnpeNgMFqzlwD8= b612.me/wincmd v0.1.0/go.mod h1:kSUbCBCBciJQZi8V2gP78ZEtt8yUHaLatl/5X+V+4Fc=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

File diff suppressed because it is too large Load Diff

View File

@ -2,11 +2,15 @@ package hosts
import ( import (
"fmt" "fmt"
"os"
"path/filepath"
"strings"
"testing" "testing"
) )
func Test_Hosts(t *testing.T) { func Test_Hosts(t *testing.T) {
var h = NewHosts() var h = NewHosts()
tmpDir := t.TempDir()
err := h.Parse("./test_hosts.txt") err := h.Parse("./test_hosts.txt")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -89,7 +93,7 @@ func Test_Hosts(t *testing.T) {
t.Log(data) t.Log(data)
} }
err = h.SaveAs("./test_hosts_01.txt") err = h.SaveAs(filepath.Join(tmpDir, "test_hosts_01.txt"))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -133,7 +137,7 @@ func Test_Hosts(t *testing.T) {
t.Error("Expected 1 got ", data) t.Error("Expected 1 got ", data)
} }
err = h.SaveAs("./test_hosts_02.txt") err = h.SaveAs(filepath.Join(tmpDir, "test_hosts_02.txt"))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -152,3 +156,577 @@ func BenchmarkAddHosts(b *testing.B) {
} }
} }
} }
func TestParseHandlesEmptyAndNoTrailingNewline(t *testing.T) {
t.Run("empty file", func(t *testing.T) {
h := NewHosts()
path := filepath.Join(t.TempDir(), "hosts.empty")
if err := os.WriteFile(path, nil, 0o644); err != nil {
t.Fatal(err)
}
if err := h.Parse(path); err != nil {
t.Fatal(err)
}
if got := h.List(); len(got) != 0 {
t.Fatalf("expected empty hosts list, got %d entries", len(got))
}
})
t.Run("last line without newline", func(t *testing.T) {
h := NewHosts()
path := filepath.Join(t.TempDir(), "hosts.nonewline")
if err := os.WriteFile(path, []byte("1.2.3.4 example.test"), 0o644); err != nil {
t.Fatal(err)
}
if err := h.Parse(path); err != nil {
t.Fatal(err)
}
if got := h.ListIPsByHost("example.test"); len(got) != 1 || got[0] != "1.2.3.4" {
t.Fatalf("expected last line to be parsed, got %v", got)
}
node, err := h.GetLatestNode()
if err != nil {
t.Fatal(err)
}
if node.NextUID() != 0 {
t.Fatalf("expected last node next uid 0, got %d", node.NextUID())
}
})
}
func TestAddHostsAndAddNodeWorkOnEmptyModel(t *testing.T) {
t.Run("add hosts", func(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "example.test"); err != nil {
t.Fatal(err)
}
if got := h.ListFirstIPByHost("example.test"); got != "1.2.3.4" {
t.Fatalf("expected inserted host ip, got %q", got)
}
out, err := h.Build()
if err != nil {
t.Fatal(err)
}
if !strings.Contains(string(out), "1.2.3.4 example.test") {
t.Fatalf("unexpected build output: %q", out)
}
})
t.Run("add node", func(t *testing.T) {
h := NewHosts()
node := &HostNode{}
node.SetIP("5.6.7.8")
node.SetHosts("node.test")
if err := h.AddNode(node); err != nil {
t.Fatal(err)
}
if node.UID() == 0 {
t.Fatal("expected node uid to be assigned")
}
if got := h.ListFirstIPByHost("node.test"); got != "5.6.7.8" {
t.Fatalf("expected inserted node ip, got %q", got)
}
})
}
func TestInsertNodeByDataInsertsAndLinksNode(t *testing.T) {
h := NewHosts()
path := filepath.Join(t.TempDir(), "hosts.insert")
if err := os.WriteFile(path, []byte("2.2.2.2 anchor.test\n"), 0o644); err != nil {
t.Fatal(err)
}
if err := h.Parse(path); err != nil {
t.Fatal(err)
}
anchor, err := h.GetFirstNode()
if err != nil {
t.Fatal(err)
}
if err := h.InsertNodeByData(anchor, true, "before", "1.1.1.1", "before.test"); err != nil {
t.Fatal(err)
}
if err := h.InsertNodeByData(anchor, false, "after", "3.3.3.3", "after.test"); err != nil {
t.Fatal(err)
}
nodes := h.List()
if len(nodes) != 3 {
t.Fatalf("expected 3 nodes after insert, got %d", len(nodes))
}
if nodes[0].IP() != "1.1.1.1" || nodes[1].IP() != "2.2.2.2" || nodes[2].IP() != "3.3.3.3" {
t.Fatalf("unexpected node order: %q, %q, %q", nodes[0].IP(), nodes[1].IP(), nodes[2].IP())
}
if got := h.ListFirstIPByHost("before.test"); got != "1.1.1.1" {
t.Fatalf("expected before node to be indexed, got %q", got)
}
if got := h.ListFirstIPByHost("after.test"); got != "3.3.3.3" {
t.Fatalf("expected after node to be indexed, got %q", got)
}
if nodes[0].NextUID() != nodes[1].UID() || nodes[1].LastUID() != nodes[0].UID() {
t.Fatalf("before/anchor linkage broken: before.next=%d anchor.uid=%d anchor.last=%d", nodes[0].NextUID(), nodes[1].UID(), nodes[1].LastUID())
}
if nodes[1].NextUID() != nodes[2].UID() || nodes[2].LastUID() != nodes[1].UID() {
t.Fatalf("anchor/after linkage broken: anchor.next=%d after.uid=%d after.last=%d", nodes[1].NextUID(), nodes[2].UID(), nodes[2].LastUID())
}
}
func TestInsertNodeByDataRejectsNilAnchor(t *testing.T) {
h := NewHosts()
path := filepath.Join(t.TempDir(), "hosts.insert.nil")
if err := os.WriteFile(path, []byte("2.2.2.2 anchor.test\n"), 0o644); err != nil {
t.Fatal(err)
}
if err := h.Parse(path); err != nil {
t.Fatal(err)
}
if err := h.InsertNodeByData(nil, true, "before", "1.1.1.1", "before.test"); err == nil {
t.Fatal("expected nil anchor error")
}
}
func TestSetIPHostsUpdatesReverseIndex(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "old.test"); err != nil {
t.Fatal(err)
}
if err := h.SetIPHosts("1.2.3.4", "new.test"); err != nil {
t.Fatal(err)
}
if got := h.ListIPsByHost("new.test"); len(got) != 1 || got[0] != "1.2.3.4" {
t.Fatalf("expected new reverse index, got %v", got)
}
if got := h.ListIPsByHost("old.test"); len(got) != 0 {
t.Fatalf("expected old reverse index to be removed, got %v", got)
}
}
func TestSetIPHostsDeduplicatesHosts(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "old.test"); err != nil {
t.Fatal(err)
}
if err := h.SetIPHosts("1.2.3.4", "new.test", "new.test"); err != nil {
t.Fatal(err)
}
if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "new.test" {
t.Fatalf("expected deduplicated ip mapping, got %v", got)
}
if got := h.ListIPsByHost("new.test"); len(got) != 1 || got[0] != "1.2.3.4" {
t.Fatalf("expected deduplicated reverse index, got %v", got)
}
}
func TestSetIPHostsReplacesMultipleSameIPNodes(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "first.test"); err != nil {
t.Fatal(err)
}
if err := h.AddHosts("1.2.3.4", "second.test"); err != nil {
t.Fatal(err)
}
if err := h.AddHosts("5.6.7.8", "tail.test"); err != nil {
t.Fatal(err)
}
if err := h.SetIPHosts("1.2.3.4", "new.test"); err != nil {
t.Fatal(err)
}
if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "new.test" {
t.Fatalf("expected replaced same-ip mappings, got %v", got)
}
if got := h.ListIPsByHost("first.test"); len(got) != 0 {
t.Fatalf("expected first old host to disappear, got %v", got)
}
if got := h.ListIPsByHost("second.test"); len(got) != 0 {
t.Fatalf("expected second old host to disappear, got %v", got)
}
if got := h.ListFirstIPByHost("tail.test"); got != "5.6.7.8" {
t.Fatalf("expected tail node to remain linked, got %q", got)
}
nodes := h.List()
if len(nodes) != 2 || nodes[0].IP() != "5.6.7.8" || nodes[1].IP() != "1.2.3.4" {
t.Fatalf("unexpected node list after SetIPHosts: %#v", nodes)
}
if nodes[0].NextUID() != nodes[1].UID() || nodes[1].LastUID() != nodes[0].UID() {
t.Fatalf("remaining node linkage broken: first.next=%d second.uid=%d second.last=%d", nodes[0].NextUID(), nodes[1].UID(), nodes[1].LastUID())
}
}
func TestSetHostIPsReplacesMappingsInOneOperation(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "old.test"); err != nil {
t.Fatal(err)
}
if err := h.SetHostIPs("old.test", "2.2.2.2", "3.3.3.3"); err != nil {
t.Fatal(err)
}
if got := h.ListIPsByHost("old.test"); len(got) != 2 || got[0] != "2.2.2.2" || got[1] != "3.3.3.3" {
t.Fatalf("expected replaced host ip mappings, got %v", got)
}
if got := h.ListHostsByIP("1.2.3.4"); len(got) != 0 {
t.Fatalf("expected old ip mapping to be removed, got %v", got)
}
}
func TestSetIPHostsRejectsInvalidInputWithoutMutating(t *testing.T) {
tests := []struct {
name string
ip string
hosts []string
}{
{name: "bad ip", ip: "bad-ip", hosts: []string{"new.test"}},
{name: "empty host", ip: "1.2.3.4", hosts: []string{""}},
{name: "comment host", ip: "1.2.3.4", hosts: []string{"#bad.test"}},
{name: "missing host", ip: "1.2.3.4"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "old.test"); err != nil {
t.Fatal(err)
}
if err := h.SetIPHosts(tt.ip, tt.hosts...); err == nil {
t.Fatal("expected invalid SetIPHosts input to fail")
}
if got := h.ListIPsByHost("old.test"); len(got) != 1 || got[0] != "1.2.3.4" {
t.Fatalf("old host mapping should remain after failed SetIPHosts, got %v", got)
}
if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "old.test" {
t.Fatalf("ip index should remain after failed SetIPHosts, got %v", got)
}
if got := h.ListIPsByHost("new.test"); len(got) != 0 {
t.Fatalf("failed SetIPHosts should not add new host, got %v", got)
}
})
}
}
func TestSetHostIPsRejectsInvalidInputWithoutMutating(t *testing.T) {
tests := []struct {
name string
host string
ips []string
}{
{name: "empty host", host: "", ips: []string{"2.2.2.2"}},
{name: "comment host", host: "#old.test", ips: []string{"2.2.2.2"}},
{name: "bad ip", host: "old.test", ips: []string{"2.2.2.2", "bad-ip"}},
{name: "missing ip", host: "old.test"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "old.test"); err != nil {
t.Fatal(err)
}
if err := h.SetHostIPs(tt.host, tt.ips...); err == nil {
t.Fatal("expected invalid SetHostIPs input to fail")
}
if got := h.ListIPsByHost("old.test"); len(got) != 1 || got[0] != "1.2.3.4" {
t.Fatalf("old host mapping should remain after failed SetHostIPs, got %v", got)
}
if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "old.test" {
t.Fatalf("ip index should remain after failed SetHostIPs, got %v", got)
}
if got := h.ListHostsByIP("2.2.2.2"); len(got) != 0 {
t.Fatalf("failed SetHostIPs should not add partial ip mapping, got %v", got)
}
})
}
}
func TestRemoveIPHostsKeepsSameIPOtherNodes(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "first.test"); err != nil {
t.Fatal(err)
}
if err := h.AddHosts("1.2.3.4", "second.test"); err != nil {
t.Fatal(err)
}
if err := h.RemoveIPHosts("1.2.3.4", "first.test"); err != nil {
t.Fatal(err)
}
if got := h.ListIPsByHost("first.test"); len(got) != 0 {
t.Fatalf("expected removed host to disappear, got %v", got)
}
if got := h.ListIPsByHost("second.test"); len(got) != 1 || got[0] != "1.2.3.4" {
t.Fatalf("expected same-ip sibling node to stay indexed, got %v", got)
}
if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "second.test" {
t.Fatalf("expected ip index to keep sibling host, got %v", got)
}
}
func TestRemoveHostsKeepsSameIPOtherNodes(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "first.test"); err != nil {
t.Fatal(err)
}
if err := h.AddHosts("1.2.3.4", "second.test"); err != nil {
t.Fatal(err)
}
if err := h.RemoveHosts("first.test"); err != nil {
t.Fatal(err)
}
if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "second.test" {
t.Fatalf("expected ip index to keep sibling host after RemoveHosts, got %v", got)
}
}
func TestRemoveIPsUnlinksAdjacentNodes(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "first.test"); err != nil {
t.Fatal(err)
}
if err := h.AddHosts("1.2.3.4", "second.test"); err != nil {
t.Fatal(err)
}
if err := h.AddHosts("5.6.7.8", "tail.test"); err != nil {
t.Fatal(err)
}
if err := h.RemoveIPs("1.2.3.4"); err != nil {
t.Fatal(err)
}
nodes := h.List()
if len(nodes) != 1 || nodes[0].IP() != "5.6.7.8" || nodes[0].LastUID() != 0 || nodes[0].NextUID() != 0 {
t.Fatalf("expected only tail node with clean links, got %#v", nodes)
}
if got := h.ListHostsByIP("1.2.3.4"); len(got) != 0 {
t.Fatalf("expected removed ip index to be empty, got %v", got)
}
if got := h.ListFirstIPByHost("tail.test"); got != "5.6.7.8" {
t.Fatalf("expected tail reverse index to remain, got %q", got)
}
}
func TestAddHostsRejectsInvalidInput(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("not-an-ip", "bad.test"); err == nil {
t.Fatal("expected invalid ip error")
}
if got := h.ListHostsByIP("not-an-ip"); len(got) != 0 {
t.Fatalf("invalid ip should not be indexed, got %v", got)
}
if err := h.AddHosts("1.2.3.4", ""); err == nil {
t.Fatal("expected empty host error")
}
if got := h.ListHostsByIP("1.2.3.4"); len(got) != 0 {
t.Fatalf("empty host should not be indexed, got %v", got)
}
}
func TestInsertNodeByDataRejectsInvalidHostDataWithoutMutating(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("2.2.2.2", "anchor.test"); err != nil {
t.Fatal(err)
}
anchor, err := h.GetFirstNode()
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
ip string
hosts []string
}{
{name: "bad ip", ip: "not-an-ip", hosts: []string{"bad.test"}},
{name: "empty host", ip: "1.1.1.1", hosts: []string{""}},
{name: "comment host", ip: "1.1.1.1", hosts: []string{"#bad.test"}},
{name: "missing host", ip: "1.1.1.1"},
}
for _, tt := range tests {
if err := h.InsertNodeByData(anchor, false, "", tt.ip, tt.hosts...); err == nil {
t.Fatalf("%s: expected error", tt.name)
}
}
nodes := h.List()
if len(nodes) != 1 || nodes[0].IP() != "2.2.2.2" {
t.Fatalf("invalid insert should not mutate node list: %#v", nodes)
}
if got := h.ListHostsByIP("1.1.1.1"); len(got) != 0 {
t.Fatalf("invalid insert should not mutate ip index: %v", got)
}
if err := h.InsertNodeByData(anchor, true, "comment-only", ""); err != nil {
t.Fatalf("comment-only insert should remain valid: %v", err)
}
nodes = h.List()
if len(nodes) != 2 || !nodes[0].OnlyComment() || nodes[1].IP() != "2.2.2.2" {
t.Fatalf("comment-only insert mismatch: %#v", nodes)
}
}
func TestInsertNodeByDataRejectsEmptyNodeWithoutMutating(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("2.2.2.2", "anchor.test"); err != nil {
t.Fatal(err)
}
anchor, err := h.GetFirstNode()
if err != nil {
t.Fatal(err)
}
if err := h.InsertNodeByData(anchor, true, "", ""); err == nil {
t.Fatal("expected empty insert to fail")
}
nodes := h.List()
if len(nodes) != 1 || nodes[0].IP() != "2.2.2.2" {
t.Fatalf("empty insert should not mutate node list: %#v", nodes)
}
out, err := h.Build()
if err != nil {
t.Fatal(err)
}
if got, want := string(out), "2.2.2.2 anchor.test"+lineBreaker; got != want {
t.Fatalf("empty insert should not change output: got %q want %q", got, want)
}
}
func TestEmptyHostsBuildAndSaveAs(t *testing.T) {
h := NewHosts()
path := filepath.Join(t.TempDir(), "hosts.empty")
if err := os.WriteFile(path, nil, 0o644); err != nil {
t.Fatal(err)
}
if err := h.Parse(path); err != nil {
t.Fatal(err)
}
out, err := h.Build()
if err != nil {
t.Fatal(err)
}
if len(out) != 0 {
t.Fatalf("expected empty build output, got %q", out)
}
outPath := filepath.Join(t.TempDir(), "hosts.out")
if err := h.SaveAs(outPath); err != nil {
t.Fatal(err)
}
saved, err := os.ReadFile(outPath)
if err != nil {
t.Fatal(err)
}
if len(saved) != 0 {
t.Fatalf("expected empty saved file, got %q", saved)
}
}
func TestParsePreservesBlankAndRawLines(t *testing.T) {
h := NewHosts()
path := filepath.Join(t.TempDir(), "hosts.raw")
input := []byte("127.0.0.1 localhost\n\nbadline\n# tail comment\n")
if err := os.WriteFile(path, input, 0o644); err != nil {
t.Fatal(err)
}
if err := h.Parse(path); err != nil {
t.Fatal(err)
}
out, err := h.Build()
if err != nil {
t.Fatal(err)
}
got := string(out)
if !strings.Contains(got, "127.0.0.1 localhost"+lineBreaker+lineBreaker+"badline"+lineBreaker) {
t.Fatalf("expected blank/raw lines to be preserved, got %q", got)
}
if !strings.Contains(got, "# tail comment"+lineBreaker) {
t.Fatalf("expected comment line to be preserved, got %q", got)
}
}
func TestHostAccessorsReturnDetachedCopies(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "example.test"); err != nil {
t.Fatal(err)
}
node, err := h.GetFirstNode()
if err != nil {
t.Fatal(err)
}
node.SetIP("9.9.9.9")
node.SetHosts("mutated.test")
if got := h.ListFirstIPByHost("example.test"); got != "1.2.3.4" {
t.Fatalf("detached copy mutated internal host index: %q", got)
}
if got := h.ListFirstIPByHost("mutated.test"); got != "" {
t.Fatalf("detached copy should not create new host index: %q", got)
}
out, err := h.Build()
if err != nil {
t.Fatal(err)
}
if strings.Contains(string(out), "9.9.9.9 mutated.test") {
t.Fatalf("detached copy leaked into build output: %q", out)
}
}
func TestUpdateNodeRejectsInvalidMutationAndPreservesState(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "example.test"); err != nil {
t.Fatal(err)
}
node, err := h.GetFirstNode()
if err != nil {
t.Fatal(err)
}
node.SetIP("bad-ip")
if err := h.UpdateNode(node); err == nil {
t.Fatal("expected invalid update to fail")
}
if got := h.ListFirstIPByHost("example.test"); got != "1.2.3.4" {
t.Fatalf("failed update should preserve previous index, got %q", got)
}
out, err := h.Build()
if err != nil {
t.Fatal(err)
}
if !strings.Contains(string(out), "1.2.3.4 example.test") || strings.Contains(string(out), "bad-ip") {
t.Fatalf("failed update should preserve previous output, got %q", out)
}
}
func TestUpdateNodeCommentOnlyStateReindexesCleanly(t *testing.T) {
h := NewHosts()
if err := h.AddHosts("1.2.3.4", "example.test"); err != nil {
t.Fatal(err)
}
node, err := h.GetFirstNode()
if err != nil {
t.Fatal(err)
}
node.SetIP("")
node.SetHosts()
node.SetComment("note")
if err := h.UpdateNode(node); err != nil {
t.Fatalf("comment-only update failed: %v", err)
}
if got := h.ListByIP(""); len(got) != 0 {
t.Fatalf("comment-only node should not be indexed under empty ip: %#v", got)
}
updated, err := h.GetNode(node.UID())
if err != nil {
t.Fatal(err)
}
if !updated.OnlyComment() {
t.Fatalf("comment-only node should keep onlyComment state: %#v", updated)
}
node = updated
node.SetIP("2.2.2.2")
node.SetHosts("restored.test")
if err := h.UpdateNode(node); err != nil {
t.Fatalf("restoring host entry failed: %v", err)
}
updated, err = h.GetNode(node.UID())
if err != nil {
t.Fatal(err)
}
if updated.OnlyComment() {
t.Fatalf("host entry should clear onlyComment after restore: %#v", updated)
}
if got := h.ListFirstIPByHost("restored.test"); got != "2.2.2.2" {
t.Fatalf("restored host index mismatch: %q", got)
}
if got := h.ListByIP(""); len(got) != 0 {
t.Fatalf("restored host should still avoid empty ip index: %#v", got)
}
}

632
math.go
View File

@ -1,305 +1,403 @@
package staros package staros
import ( import (
"errors"
"fmt" "fmt"
"math" "math"
"strconv" "strconv"
"strings" "strings"
"unicode"
) )
func Calc(math string) (float64, error) { // Calc evaluates a small frozen arithmetic expression language kept for
math = strings.Replace(math, " ", "", -1) // compatibility with older staros callers.
math = strings.ToLower(math) func Calc(expr string) (float64, error) {
if err := check(math); err != nil { parser := calcParser{input: strings.ToLower(strings.TrimSpace(expr))}
if parser.input == "" {
return 0, fmt.Errorf("empty expression")
}
value, err := parser.parseExpression()
if err != nil {
return 0, err return 0, err
} }
result,err:=calc(math) parser.skipSpace()
if err!=nil { if !parser.done() {
return 0,err return 0, fmt.Errorf("unexpected token %q at position %d", parser.peek(), parser.pos)
} }
return floatRound(result,15),nil return normalizeCalcFloat(value), nil
} }
func floatRound(f float64, n int) float64 { type calcParser struct {
format := "%." + strconv.Itoa(n) + "f" input string
res, _ := strconv.ParseFloat(fmt.Sprintf(format, f), 64) pos int
return res
} }
func check(math string) error { func (p *calcParser) parseExpression() (float64, error) {
math = strings.Replace(math, " ", "", -1) return p.parseAddSub()
math = strings.ToLower(math)
var bracketSum int
var signReady bool
for k, v := range math {
if string([]rune{v}) == "(" {
bracketSum++
}
if string([]rune{v}) == ")" {
bracketSum--
}
if bracketSum < 0 {
return fmt.Errorf("err at position %d.Reason is right bracket position not correct,except (", k)
}
if containSign(string([]rune{v})) {
if signReady {
if string([]rune{v}) != "+" && string([]rune{v}) != "-" {
return fmt.Errorf("err at position %d.Reason is sign %s not correct", k, string([]rune{v}))
}
} else {
signReady = true
continue
}
}
signReady = false
}
if bracketSum != 0 {
return fmt.Errorf("Error:right bracket is not equal as left bracket")
}
return nil
} }
func calc(math string) (float64, error) { func (p *calcParser) parseAddSub() (float64, error) {
var bracketLeft int left, err := p.parseMulDiv()
var bracketRight int
var DupStart int = -1
for pos, str := range math {
if string(str) == "(" {
bracketLeft = pos
}
if string(str) == ")" {
bracketRight = pos
break
}
}
if bracketRight == 0 && bracketLeft != 0 || (bracketLeft > bracketRight) {
return 0, fmt.Errorf("Error:bracket not correct at %d ,except )", bracketLeft)
}
if bracketRight == 0 && bracketLeft == 0 {
return calcLong(math)
}
line := math[bracketLeft+1 : bracketRight]
num, err := calcLong(line)
if err != nil { if err != nil {
return 0, err return 0, err
} }
for i := bracketLeft - 1; i >= 0; i-- { for {
if !containSign(math[i : i+1]) { p.skipSpace()
DupStart = i switch p.peek() {
continue case '+':
} p.pos++
break right, err := p.parseMulDiv()
}
if DupStart != -1 {
sign := math[DupStart:bracketLeft]
num, err := calcDuaFloat(sign, num)
if err != nil { if err != nil {
return 0, err return 0, err
} }
math = math[:DupStart] + fmt.Sprintf("%.15f", num) + math[bracketRight+1:] left += right
DupStart = -1 case '-':
} else { p.pos++
math = math[:bracketLeft] + fmt.Sprintf("%.15f", num) + math[bracketRight+1:] right, err := p.parseMulDiv()
}
return calc(math)
}
func calcLong(str string) (float64, error) {
var sigReady bool = false
var sigApply bool = false
var numPool []float64
var operPool []string
var numStr string
var oper string
if str[0:1] == "+" || str[0:1] == "-" {
sigReady = true
}
for _, stp := range str {
if sigReady && containSign(string(stp)) {
sigReady = false
sigApply = true
oper = string(stp)
continue
}
if !containSign(string(stp)) {
sigReady = false
numStr = string(append([]rune(numStr), stp))
continue
}
if !sigReady {
sigReady = true
}
if sigApply {
num, err := calcDua(oper, numStr)
if err != nil { if err != nil {
return 0, err return 0, err
} }
sigApply = false left -= right
numPool = append(numPool, num)
} else {
num, err := parseNumbic(numStr)
if err != nil {
return 0, err
}
numPool = append(numPool, num)
}
numStr = ""
operPool = append(operPool, string(stp))
}
if sigApply {
num, err := calcDua(oper, numStr)
if err != nil {
return 0, err
}
numPool = append(numPool, num)
} else {
num, err := parseNumbic(numStr)
if err != nil {
return 0, err
}
numPool = append(numPool, num)
}
return calcPool(numPool, operPool)
}
func calcPool(numPool []float64, operPool []string) (float64, error) {
if len(numPool) == 1 && len(operPool) == 0 {
return numPool[0], nil
}
if len(numPool) < len(operPool) {
return 0, errors.New(("Operate Signal Is too much"))
}
calcFunc := func(k int, v string) (float64, error) {
num, err := calcSigFloat(numPool[k], v, numPool[k+1])
if err != nil {
return 0, err
}
tmp := append(numPool[:k], num)
numPool = append(tmp, numPool[k+2:]...)
operPool = append(operPool[:k], operPool[k+1:]...)
return calcPool(numPool, operPool)
}
for k, v := range operPool {
if v == "^" {
return calcFunc(k, v)
}
}
for k, v := range operPool {
if v == "*" || v == "/" {
return calcFunc(k, v)
}
}
for k, v := range operPool {
return calcFunc(k, v)
}
return 0, nil
}
func calcSigFloat(floatA float64, b string, floatC float64) (float64, error) {
switch b {
case "+":
return floatRound(floatA + floatC,15), nil
case "-":
return floatRound(floatA - floatC,15), nil
case "*":
return floatRound(floatA * floatC,15), nil
case "/":
if floatC == 0 {
return 0, errors.New("Divisor cannot be 0")
}
return floatRound(floatA / floatC,15), nil
case "^":
return math.Pow(floatA, floatC), nil
}
return 0, fmt.Errorf("unexpect method:%s", b)
}
func calcSig(a, b, c string) (float64, error) {
floatA, err := parseNumbic(a)
if err != nil {
return 0, err
}
floatC, err := parseNumbic(c)
if err != nil {
return 0, err
}
return calcSigFloat(floatA, b, floatC)
}
func calcDuaFloat(a string, floatB float64) (float64, error) {
switch a {
case "sin":
return math.Sin(floatB), nil
case "cos":
return math.Cos(floatB), nil
case "tan":
return math.Tan(floatB), nil
case "abs":
return math.Abs(floatB), nil
case "arcsin":
return math.Asin(floatB), nil
case "arccos":
return math.Acos(floatB), nil
case "arctan":
return math.Atan(floatB), nil
case "sqrt":
return math.Sqrt(floatB), nil
case "loge":
return math.Log(floatB), nil
case "log10":
return math.Log10(floatB), nil
case "log2":
return math.Log2(floatB), nil
case "floor":
return math.Floor(floatB), nil
case "ceil":
return math.Ceil(floatB), nil
case "round":
return math.Round(floatB), nil
case "trunc":
return math.Trunc(floatB), nil
case "+":
return 0 + floatB, nil
case "-":
return 0 - floatB, nil
}
return 0, fmt.Errorf("unexpect method:%s", a)
}
func calcDua(a, b string) (float64, error) {
floatB, err := parseNumbic(b)
if err != nil {
return 0, err
}
return calcDuaFloat(a, floatB)
}
func parseNumbic(str string) (float64, error) {
switch str {
case "pi":
return float64(math.Pi), nil
case "e":
return float64(math.E), nil
default: default:
return strconv.ParseFloat(str, 64) return left, nil
}
} }
} }
func containSign(str string) bool { func (p *calcParser) parseMulDiv() (float64, error) {
var sign []string = []string{"+", "-", "*", "/", "^"} left, err := p.parsePower()
for _, v := range sign { if err != nil {
if str == v { return 0, err
return true }
for {
p.skipSpace()
switch p.peek() {
case '*':
p.pos++
right, err := p.parsePower()
if err != nil {
return 0, err
}
left *= right
case '/':
p.pos++
right, err := p.parsePower()
if err != nil {
return 0, err
}
if right == 0 {
return 0, fmt.Errorf("divisor cannot be 0")
}
left /= right
default:
return left, nil
} }
} }
return false
} }
func contain(pool []string, str string) bool { func (p *calcParser) parsePower() (float64, error) {
for _, v := range pool { left, err := p.parseUnary()
if v == str { if err != nil {
return 0, err
}
p.skipSpace()
if p.peek() != '^' {
return left, nil
}
p.pos++
right, err := p.parsePower()
if err != nil {
return 0, err
}
return math.Pow(left, right), nil
}
func (p *calcParser) parseUnary() (float64, error) {
p.skipSpace()
switch p.peek() {
case '+':
p.pos++
return p.parseUnary()
case '-':
p.pos++
value, err := p.parseUnary()
if err != nil {
return 0, err
}
return -value, nil
default:
return p.parsePrimary()
}
}
func (p *calcParser) parsePrimary() (float64, error) {
p.skipSpace()
if p.done() {
return 0, fmt.Errorf("unexpected end of expression")
}
ch := p.peek()
switch {
case ch == '(':
p.pos++
value, err := p.parseExpression()
if err != nil {
return 0, err
}
p.skipSpace()
if p.peek() != ')' {
return 0, fmt.Errorf("missing ')' at position %d", p.pos)
}
p.pos++
return value, nil
case isCalcNumberStart(p.input, p.pos):
return p.parseNumber()
case isCalcIdentStart(ch):
return p.parseIdentifier()
default:
return 0, fmt.Errorf("unexpected token %q at position %d", ch, p.pos)
}
}
func (p *calcParser) parseNumber() (float64, error) {
start := p.pos
seenDot := false
seenExp := false
for !p.done() {
ch := p.peek()
switch {
case ch >= '0' && ch <= '9':
p.pos++
case ch == '.' && !seenDot && !seenExp:
seenDot = true
p.pos++
case (ch == 'e') && !seenExp:
seenExp = true
p.pos++
if !p.done() && (p.peek() == '+' || p.peek() == '-') {
p.pos++
}
default:
value, err := strconv.ParseFloat(p.input[start:p.pos], 64)
if err != nil {
return 0, fmt.Errorf("invalid number %q at position %d", p.input[start:p.pos], start)
}
return value, nil
}
}
value, err := strconv.ParseFloat(p.input[start:p.pos], 64)
if err != nil {
return 0, fmt.Errorf("invalid number %q at position %d", p.input[start:p.pos], start)
}
return value, nil
}
func (p *calcParser) parseIdentifier() (float64, error) {
start := p.pos
for !p.done() && isCalcIdent(p.peek()) {
p.pos++
}
name := p.input[start:p.pos]
p.skipSpace()
if p.peek() != '(' {
value, ok := calcConstant(name)
if !ok {
return 0, fmt.Errorf("unknown identifier %q at position %d", name, start)
}
return value, nil
}
p.pos++
args, err := p.parseArguments(name)
if err != nil {
return 0, err
}
return calcFunction(name, args)
}
func (p *calcParser) parseArguments(name string) ([]float64, error) {
p.skipSpace()
if p.peek() == ')' {
p.pos++
return nil, nil
}
var args []float64
for {
arg, err := p.parseExpression()
if err != nil {
return nil, err
}
args = append(args, arg)
p.skipSpace()
if p.peek() != ',' {
break
}
p.pos++
p.skipSpace()
if p.peek() == ')' {
return nil, fmt.Errorf("missing argument for function %q", name)
}
}
if p.peek() != ')' {
return nil, fmt.Errorf("missing ')' after function %q", name)
}
p.pos++
return args, nil
}
func (p *calcParser) skipSpace() {
for !p.done() && unicode.IsSpace(rune(p.peek())) {
p.pos++
}
}
func (p *calcParser) done() bool {
return p.pos >= len(p.input)
}
func (p *calcParser) peek() byte {
if p.done() {
return 0
}
return p.input[p.pos]
}
func isCalcNumberStart(input string, pos int) bool {
ch := input[pos]
if ch >= '0' && ch <= '9' {
return true return true
} }
} return ch == '.' && pos+1 < len(input) && input[pos+1] >= '0' && input[pos+1] <= '9'
return false }
func isCalcIdentStart(ch byte) bool {
return ch >= 'a' && ch <= 'z'
}
func isCalcIdent(ch byte) bool {
return (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_'
}
func calcConstant(name string) (float64, bool) {
switch name {
case "pi":
return math.Pi, true
case "e":
return math.E, true
default:
return 0, false
}
}
func calcFunction(name string, args []float64) (float64, error) {
argCount := len(args)
if !calcFunctionArgCountValid(name, argCount) {
return 0, fmt.Errorf("function %q accepts %s, got %d", name, calcFunctionArgSpec(name), argCount)
}
switch name {
case "sin":
return math.Sin(args[0]), nil
case "cos":
return math.Cos(args[0]), nil
case "tan":
return math.Tan(args[0]), nil
case "sinh":
return math.Sinh(args[0]), nil
case "cosh":
return math.Cosh(args[0]), nil
case "tanh":
return math.Tanh(args[0]), nil
case "abs":
return math.Abs(args[0]), nil
case "arcsin", "asin":
return math.Asin(args[0]), nil
case "arccos", "acos":
return math.Acos(args[0]), nil
case "arctan", "atan":
return math.Atan(args[0]), nil
case "sqrt":
return math.Sqrt(args[0]), nil
case "cbrt":
return math.Cbrt(args[0]), nil
case "exp":
return math.Exp(args[0]), nil
case "loge", "ln":
return math.Log(args[0]), nil
case "log":
return math.Log10(args[0]), nil
case "log10":
return math.Log10(args[0]), nil
case "log2":
return math.Log2(args[0]), nil
case "floor":
return math.Floor(args[0]), nil
case "ceil":
return math.Ceil(args[0]), nil
case "round":
return math.Round(args[0]), nil
case "trunc":
return math.Trunc(args[0]), nil
case "rad":
return args[0] * math.Pi / 180.0, nil
case "deg":
return args[0] * 180.0 / math.Pi, nil
case "pow":
return math.Pow(args[0], args[1]), nil
case "hypot":
return math.Hypot(args[0], args[1]), nil
case "min":
result := args[0]
for _, arg := range args[1:] {
if arg < result {
result = arg
}
}
return result, nil
case "max":
result := args[0]
for _, arg := range args[1:] {
if arg > result {
result = arg
}
}
return result, nil
default:
return 0, fmt.Errorf("unknown function %q", name)
}
}
func calcFunctionArgCountValid(name string, count int) bool {
switch name {
case "pow", "hypot":
return count == 2
case "min", "max":
return count >= 1
case "sin", "cos", "tan", "sinh", "cosh", "tanh",
"abs", "arcsin", "asin", "arccos", "acos", "arctan", "atan",
"sqrt", "cbrt", "exp", "loge", "ln", "log", "log10", "log2",
"floor", "ceil", "round", "trunc", "rad", "deg":
return count == 1
default:
return true
}
}
func calcFunctionArgSpec(name string) string {
switch name {
case "pow", "hypot":
return "exactly two arguments"
case "min", "max":
return "at least one argument"
default:
return "exactly one argument"
}
}
func normalizeCalcFloat(value float64) float64 {
text := strconv.FormatFloat(value, 'g', 15, 64)
out, err := strconv.ParseFloat(text, 64)
if err != nil {
return value
}
if out == 0 {
return 0
}
return out
} }

61
math_test.go Normal file
View File

@ -0,0 +1,61 @@
package staros
import (
"math"
"testing"
)
func TestCalcCompatibilityExpressions(t *testing.T) {
tests := []struct {
expr string
want float64
}{
{"1+2*3", 7},
{"(1+2)*3", 9},
{"60*60*24", 86400},
{"-1+2", 1},
{"sqrt(4)+abs(-3)", 5},
{"sin(pi/2)", 1},
{"arcsin(1)", math.Pi / 2},
{"asin(1)", math.Pi / 2},
{"loge(e)", 1},
{"ln(e)", 1},
{"log10(100)+log2(8)", 5},
{"floor(1.9)+ceil(1.1)+round(1.5)+trunc(1.9)", 6},
{"1.2e3+3", 1203},
{"pow(2,3)+hypot(3,4)", 13},
{"min(3,1,2)+max(3,1,2)", 4},
{"log(100)+rad(180)/pi+deg(pi)/180", 4},
{"cbrt(27)+exp(0)", 4},
{"sinh(0)+cosh(0)+tanh(0)", 1},
}
for _, tt := range tests {
got, err := Calc(tt.expr)
if err != nil {
t.Fatalf("Calc(%q) failed: %v", tt.expr, err)
}
if math.Abs(got-tt.want) > 1e-12 {
t.Fatalf("Calc(%q)=%v, want %v", tt.expr, got, tt.want)
}
}
}
func TestCalcRejectsInvalidExpressions(t *testing.T) {
tests := []string{
"",
"1/",
"(1+2",
"1/0",
"unknown(1)",
"min()",
"pow(2)",
"pow(2,3,4)",
"sqrt(1,2)",
"pi()",
}
for _, expr := range tests {
if got, err := Calc(expr); err == nil {
t.Fatalf("Calc(%q)=%v, expected error", expr, got)
}
}
}

View File

@ -1,4 +1,5 @@
//+build darwin //go:build darwin
// +build darwin
package staros package staros
@ -13,7 +14,7 @@ import (
) )
// Memory 系统内存信息 // Memory 系统内存信息
func Memory() (MemStatus,error) { func Memory() (MemStatus, error) {
return darwinMemory() return darwinMemory()
} }

View File

@ -1,8 +1,14 @@
//+build linux //go:build linux
// +build linux
package staros package staros
import "syscall" import (
"io/ioutil"
"strconv"
"strings"
"syscall"
)
// Memory 系统内存信息 // Memory 系统内存信息
func Memory() (MemStatus, error) { func Memory() (MemStatus, error) {
@ -11,14 +17,40 @@ func Memory() (MemStatus, error) {
if err := syscall.Sysinfo(ram); err != nil { if err := syscall.Sysinfo(ram); err != nil {
return mem, err return mem, err
} }
mem.All = uint64(ram.Totalram) unit := uint64(ram.Unit)
mem.BuffCache = uint64(ram.Bufferram) if unit == 0 {
mem.Free = uint64(ram.Freeram) unit = 1
mem.Shared = uint64(ram.Sharedram) }
mem.Available = uint64(ram.Freeram + ram.Sharedram + ram.Bufferram) mem.All = uint64(ram.Totalram) * unit
mem.SwapAll = uint64(ram.Totalswap) mem.BuffCache = uint64(ram.Bufferram) * unit
mem.SwapFree = uint64(ram.Freeswap) mem.Free = uint64(ram.Freeram) * unit
mem.Shared = uint64(ram.Sharedram) * unit
mem.Available = mem.Free + mem.Shared + mem.BuffCache
if available, ok := linuxMemAvailable(); ok {
mem.Available = available
}
mem.SwapAll = uint64(ram.Totalswap) * unit
mem.SwapFree = uint64(ram.Freeswap) * unit
mem.SwapUsed = uint64(mem.SwapAll - mem.SwapFree) mem.SwapUsed = uint64(mem.SwapAll - mem.SwapFree)
mem.Used = uint64(mem.All - mem.Free) mem.Used = uint64(mem.All - mem.Free)
return mem, nil return mem, nil
} }
func linuxMemAvailable() (uint64, bool) {
data, err := ioutil.ReadFile("/proc/meminfo")
if err != nil {
return 0, false
}
for _, line := range strings.Split(string(data), "\n") {
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "MemAvailable:" {
continue
}
value, err := strconv.ParseUint(fields[1], 10, 64)
if err != nil {
return 0, false
}
return value * 1024, true
}
return 0, false
}

View File

@ -1,3 +1,4 @@
//go:build windows
// +build windows // +build windows
package staros package staros
@ -19,8 +20,9 @@ func Memory() (MemStatus, error) {
mem.SwapAll = uint64(ram.UllTotalPageFile) mem.SwapAll = uint64(ram.UllTotalPageFile)
mem.SwapFree = uint64(ram.UllAvailPageFile) mem.SwapFree = uint64(ram.UllAvailPageFile)
mem.SwapUsed = mem.SwapAll - mem.SwapFree mem.SwapUsed = mem.SwapAll - mem.SwapFree
mem.VirtualAll = uint64(mem.VirtualAll) mem.VirtualAll = uint64(ram.UllTotalVirtual)
mem.VirtualAvail = uint64(mem.VirtualAvail) mem.VirtualAvail = uint64(ram.UllAvailVirtual)
mem.VirtualUsed = mem.VirtualAll - mem.VirtualUsed mem.VirtualUsed = mem.VirtualAll - mem.VirtualAvail
mem.AvailExtended = uint64(ram.UllAvailExtendedVirtual)
return mem, nil return mem, nil
} }

30
network_darwin.go Normal file
View File

@ -0,0 +1,30 @@
//go:build darwin
// +build darwin
package staros
import "time"
func NetUsage() ([]NetAdapter, error) {
return nil, ERR_UNSUPPORTED
}
func NetUsageByname(name string) (NetAdapter, error) {
return NetAdapter{}, ERR_UNSUPPORTED
}
func NetSpeeds(duration time.Duration) ([]NetSpeed, error) {
return nil, ERR_UNSUPPORTED
}
func NetSpeedsByName(duration time.Duration, name string) (NetSpeed, error) {
return NetSpeed{}, ERR_UNSUPPORTED
}
func NetConnections(analysePid bool, types string) ([]NetConn, error) {
return nil, ERR_UNSUPPORTED
}
func GetInodeMap() (map[string]int64, error) {
return nil, ERR_UNSUPPORTED
}

View File

@ -1,9 +1,67 @@
//go:build linux
// +build linux
package staros package staros
import ( import (
"testing" "testing"
"time"
) )
func Test_TrimSpace(t *testing.T) { func Test_TrimSpace(t *testing.T) {
} }
func TestAnalyseNetFilesSkipsShortLines(t *testing.T) {
data := []byte("sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\nshort\n")
res, err := analyseNetFiles(data, nil, "tcp")
if err != nil {
t.Fatal(err)
}
if len(res) != 0 {
t.Fatalf("expected no parsed connections, got %d", len(res))
}
}
func TestNetSpeedsRejectsInvalidDuration(t *testing.T) {
_, err := NetSpeeds(0)
if err == nil {
t.Fatal("expected invalid duration error")
}
_, err = NetSpeeds(-time.Second)
if err == nil {
t.Fatal("expected invalid duration error")
}
}
func TestUniqueStrings(t *testing.T) {
got := uniqueStrings([]string{"tcp", "udp", "tcp"})
if len(got) != 2 {
t.Fatalf("expected 2 unique values, got %d", len(got))
}
if got[0] != "tcp" || got[1] != "udp" {
t.Fatalf("unexpected order: %#v", got)
}
}
func TestParseProcStatusKB(t *testing.T) {
if got := parseProcStatusKB("12 kB"); got != 12*1024 {
t.Fatalf("expected 12288, got %d", got)
}
if got := parseProcStatusKB(""); got != 0 {
t.Fatalf("expected 0, got %d", got)
}
}
func TestProcStartTimeFromStatHandlesProcessNameWithSpacesAndParens(t *testing.T) {
stat := []byte("42 (name with ) parens) S 1 1 1 0 -1 4194560 0 0 0 0 0 0 0 0 20 0 1 0 12345")
got, ok := procStartTimeFromStat(stat)
if !ok {
t.Fatal("expected proc stat start time to parse")
}
ticks := int64(clockTicks())
want := time.Unix(StartTime().Unix()+12345/ticks, (12345%ticks)*int64(time.Second)/ticks)
if !got.Equal(want) {
t.Fatalf("unexpected start time: got %s want %s", got, want)
}
}

View File

@ -1,5 +1,5 @@
//go:build !windows //go:build linux
// +build !windows // +build linux
package staros package staros
@ -15,26 +15,37 @@ import (
func NetUsage() ([]NetAdapter, error) { func NetUsage() ([]NetAdapter, error) {
data, err := ioutil.ReadFile("/proc/net/dev") data, err := ioutil.ReadFile("/proc/net/dev")
if err != nil { if err != nil {
return []NetAdapter{}, err return nil, err
} }
sps := strings.Split(strings.TrimSpace(string(data)), "\n") sps := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(sps) < 3 { if len(sps) < 3 {
return []NetAdapter{}, errors.New("No Adaptor") return nil, errors.New("No Adaptor")
} }
var res []NetAdapter var res []NetAdapter
netLists := sps[2:] netLists := sps[2:]
for _, v := range netLists { for _, v := range netLists {
v = strings.ReplaceAll(v, " ", " ") parts := strings.SplitN(strings.TrimSpace(v), ":", 2)
for strings.Contains(v, " ") { if len(parts) != 2 {
v = strings.ReplaceAll(v, " ", " ") continue
}
card := strings.Fields(parts[1])
if len(card) < 16 {
continue
}
name := strings.TrimSpace(parts[0])
recvBytes, err := strconv.ParseUint(card[0], 10, 64)
if err != nil {
continue
}
sendBytes, err := strconv.ParseUint(card[8], 10, 64)
if err != nil {
continue
} }
v = strings.TrimSpace(v)
card := strings.Split(v, " ")
name := strings.ReplaceAll(card[0], ":", "")
recvBytes, _ := strconv.Atoi(card[1])
sendBytes, _ := strconv.Atoi(card[9])
res = append(res, NetAdapter{name, uint64(recvBytes), uint64(sendBytes)}) res = append(res, NetAdapter{name, uint64(recvBytes), uint64(sendBytes)})
} }
if len(res) == 0 {
return nil, errors.New("No Adaptor")
}
return res, nil return res, nil
} }
@ -52,30 +63,48 @@ func NetUsageByname(name string) (NetAdapter, error) {
} }
func NetSpeeds(duration time.Duration) ([]NetSpeed, error) { func NetSpeeds(duration time.Duration) ([]NetSpeed, error) {
if duration <= 0 {
return nil, errors.New("duration must be positive")
}
list1, err := NetUsage() list1, err := NetUsage()
if err != nil { if err != nil {
return []NetSpeed{}, err return nil, err
} }
time.Sleep(duration) time.Sleep(duration)
list2, err := NetUsage() list2, err := NetUsage()
if err != nil { if err != nil {
return []NetSpeed{}, err return nil, err
} }
if len(list1) > len(list2) { byName := make(map[string]NetAdapter, len(list2))
return []NetSpeed{}, errors.New("NetWork Adaptor Num Not ok") for _, item := range list2 {
byName[item.Name] = item
} }
var res []NetSpeed var res []NetSpeed
for k, v := range list1 { for _, v := range list1 {
recv := float64(list2[k].RecvBytes-v.RecvBytes) / duration.Seconds() next, ok := byName[v.Name]
send := float64(list2[k].SendBytes-v.SendBytes) / duration.Seconds() if !ok {
continue
}
var recvDelta, sendDelta uint64
if next.RecvBytes >= v.RecvBytes {
recvDelta = next.RecvBytes - v.RecvBytes
}
if next.SendBytes >= v.SendBytes {
sendDelta = next.SendBytes - v.SendBytes
}
recv := float64(recvDelta) / duration.Seconds()
send := float64(sendDelta) / duration.Seconds()
res = append(res, NetSpeed{ res = append(res, NetSpeed{
Name: v.Name, Name: v.Name,
RecvSpeeds: recv, RecvSpeeds: recv,
SendSpeeds: send, SendSpeeds: send,
RecvBytes: list2[k].RecvBytes, RecvBytes: next.RecvBytes,
SendBytes: list2[k].SendBytes, SendBytes: next.SendBytes,
}) })
} }
if len(res) == 0 {
return nil, errors.New("NetWork Adaptor Num Not ok")
}
return res, nil return res, nil
} }
@ -99,7 +128,8 @@ func NetConnections(analysePid bool, types string) ([]NetConn, error) {
var inodeMap map[string]int64 var inodeMap map[string]int64
var err error var err error
var fileList []string var fileList []string
if types == "" || strings.Contains(strings.ToLower(types), "all") { types = strings.ToLower(types)
if types == "" || strings.Contains(types, "all") {
fileList = []string{ fileList = []string{
"/proc/net/tcp", "/proc/net/tcp",
"/proc/net/tcp6", "/proc/net/tcp6",
@ -107,25 +137,33 @@ func NetConnections(analysePid bool, types string) ([]NetConn, error) {
"/proc/net/udp6", "/proc/net/udp6",
"/proc/net/unix", "/proc/net/unix",
} }
} } else {
if strings.Contains(strings.ToLower(types), "tcp") { if strings.Contains(types, "tcp") {
fileList = append(fileList, "/proc/net/tcp", "/proc/net/tcp6") fileList = append(fileList, "/proc/net/tcp", "/proc/net/tcp6")
} }
if strings.Contains(strings.ToLower(types), "udp") { if strings.Contains(types, "udp") {
fileList = append(fileList, "/proc/net/udp", "/proc/net/udp6") fileList = append(fileList, "/proc/net/udp", "/proc/net/udp6")
} }
if strings.Contains(strings.ToLower(types), "unix") { if strings.Contains(types, "unix") {
fileList = append(fileList, "/proc/net/unix") fileList = append(fileList, "/proc/net/unix")
} }
}
fileList = uniqueStrings(fileList)
if len(fileList) == 0 {
return nil, errors.New("unsupported net connection type")
}
if analysePid { if analysePid {
inodeMap, err = GetInodeMap() inodeMap, err = GetInodeMap()
if err != nil { if err != nil {
return result, err inodeMap = nil
} }
} }
for _, file := range fileList { for _, file := range fileList {
data, err := ioutil.ReadFile(file) data, err := ioutil.ReadFile(file)
if err != nil { if err != nil {
if os.IsNotExist(err) {
continue
}
return result, err return result, err
} }
tmpRes, err := analyseNetFiles(data, inodeMap, file[strings.LastIndex(file, "/")+1:]) tmpRes, err := analyseNetFiles(data, inodeMap, file[strings.LastIndex(file, "/")+1:])
@ -137,6 +175,22 @@ func NetConnections(analysePid bool, types string) ([]NetConn, error) {
return result, nil return result, nil
} }
func uniqueStrings(items []string) []string {
if len(items) == 0 {
return nil
}
seen := make(map[string]struct{}, len(items))
res := make([]string, 0, len(items))
for _, item := range items {
if _, ok := seen[item]; ok {
continue
}
seen[item] = struct{}{}
res = append(res, item)
}
return res
}
func GetInodeMap() (map[string]int64, error) { func GetInodeMap() (map[string]int64, error) {
res := make(map[string]int64) res := make(map[string]int64)
paths, err := ioutil.ReadDir("/proc") paths, err := ioutil.ReadDir("/proc")
@ -186,6 +240,9 @@ func analyseNetFiles(data []byte, inodeMap map[string]int64, typed string) ([]Ne
continue continue
} }
v := strings.Split(strings.TrimSpace(lineData), " ") v := strings.Split(strings.TrimSpace(lineData), " ")
if len(v) < 10 {
continue
}
var res NetConn var res NetConn
ip, port, err := parseHexIpPort(v[1]) ip, port, err := parseHexIpPort(v[1])
if err != nil { if err != nil {
@ -205,7 +262,11 @@ func analyseNetFiles(data []byte, inodeMap map[string]int64, typed string) ([]Ne
if err != nil { if err != nil {
return result, err return result, err
} }
if state >= 0 && int(state) < len(TCP_STATE) {
res.Status = TCP_STATE[state] res.Status = TCP_STATE[state]
} else {
res.Status = TCP_STATE[TCP_UNKNOWN]
}
} }
txrx_queue := strings.Split(strings.TrimSpace(v[4]), ":") txrx_queue := strings.Split(strings.TrimSpace(v[4]), ":")
if len(txrx_queue) != 2 { if len(txrx_queue) != 2 {
@ -293,6 +354,9 @@ func analyseUnixFiles(data []byte, inodeMap map[string]int64, typed string) ([]N
continue continue
} }
v := strings.Split(strings.TrimSpace(lineData), " ") v := strings.Split(strings.TrimSpace(lineData), " ")
if len(v) < 7 {
continue
}
var res NetConn var res NetConn
res.Inode = v[6] res.Inode = v[6]
if len(v) == 8 { if len(v) == 8 {

View File

@ -1,33 +1,314 @@
//go:build windows
// +build windows // +build windows
package staros package staros
import ( import (
"errors"
"net"
"strconv"
"strings"
"syscall"
"time" "time"
"b612.me/win32api"
) )
const windowsErrorNotSupported syscall.Errno = 50
func NetUsage() ([]NetAdapter, error) { func NetUsage() ([]NetAdapter, error) {
var res []NetAdapter rows, err := win32api.GetIfTable2()
if err != nil {
return nil, err
}
res := make([]NetAdapter, 0, len(rows))
for _, row := range rows {
name := windowsInterfaceName(row)
if name == "" {
continue
}
res = append(res, NetAdapter{
Name: name,
RecvBytes: row.InOctets,
SendBytes: row.OutOctets,
})
}
if len(res) == 0 {
return nil, errors.New("No Adaptor")
}
return res, nil return res, nil
} }
func NetUsageByname(name string) (NetAdapter, error) { func NetUsageByname(name string) (NetAdapter, error) {
return NetAdapter{}, nil ada, err := NetUsage()
if err != nil {
return NetAdapter{}, err
}
for _, v := range ada {
if v.Name == name {
return v, nil
}
}
return NetAdapter{}, errors.New("Not Found")
} }
func NetSpeeds(duration time.Duration) ([]NetSpeed, error) { func NetSpeeds(duration time.Duration) ([]NetSpeed, error) {
var res []NetSpeed if duration <= 0 {
return nil, errors.New("duration must be positive")
}
list1, err := NetUsage()
if err != nil {
return nil, err
}
time.Sleep(duration)
list2, err := NetUsage()
if err != nil {
return nil, err
}
byName := make(map[string]NetAdapter, len(list2))
for _, item := range list2 {
byName[item.Name] = item
}
res := make([]NetSpeed, 0, len(list1))
for _, v := range list1 {
next, ok := byName[v.Name]
if !ok {
continue
}
var recvDelta, sendDelta uint64
if next.RecvBytes >= v.RecvBytes {
recvDelta = next.RecvBytes - v.RecvBytes
}
if next.SendBytes >= v.SendBytes {
sendDelta = next.SendBytes - v.SendBytes
}
res = append(res, NetSpeed{
Name: v.Name,
RecvSpeeds: float64(recvDelta) / duration.Seconds(),
SendSpeeds: float64(sendDelta) / duration.Seconds(),
RecvBytes: next.RecvBytes,
SendBytes: next.SendBytes,
})
}
if len(res) == 0 {
return nil, errors.New("NetWork Adaptor Num Not ok")
}
return res, nil return res, nil
} }
func NetSpeedsByName(duration time.Duration, name string) (NetSpeed, error) { func NetSpeedsByName(duration time.Duration, name string) (NetSpeed, error) {
ada, err := NetSpeeds(duration)
return NetSpeed{}, nil if err != nil {
return NetSpeed{}, err
}
for _, v := range ada {
if v.Name == name {
return v, nil
}
}
return NetSpeed{}, errors.New("Not Found")
} }
// NetConnections return all TCP/UDP/UNIX DOMAIN SOCKET Connections // NetConnections return all TCP/UDP/UNIX DOMAIN SOCKET Connections
// if your uid != 0 ,and analysePid==true ,you should have CAP_SYS_PRTACE and CAP_DAC_OVERRIDE/CAP_DAC_READ_SEARCH Caps // if your uid != 0 ,and analysePid==true ,you should have CAP_SYS_PRTACE and CAP_DAC_OVERRIDE/CAP_DAC_READ_SEARCH Caps
func NetConnections(analysePid bool) ([]NetConn, error) { func NetConnections(analysePid bool, types string) ([]NetConn, error) {
var result []NetConn wantTCP, wantUDP, err := windowsNetConnectionTypes(types)
if err != nil {
return nil, err
}
result := make([]NetConn, 0)
processCache := make(map[int64]*Process)
if wantTCP {
result, err = appendWindowsTCPConnections(result, analysePid, processCache)
if err != nil {
return result, err
}
}
if wantUDP {
result, err = appendWindowsUDPConnections(result, analysePid, processCache)
if err != nil {
return result, err
}
}
return result, nil return result, nil
} }
func GetInodeMap() (map[string]int64, error) {
return nil, ERR_UNSUPPORTED
}
func windowsInterfaceName(row win32api.MIB_IF_ROW2) string {
name := strings.TrimSpace(syscall.UTF16ToString(row.Alias[:]))
if name != "" {
return name
}
name = strings.TrimSpace(syscall.UTF16ToString(row.Description[:]))
if name != "" {
return name
}
if row.InterfaceIndex != 0 {
return "if" + strconv.FormatUint(uint64(row.InterfaceIndex), 10)
}
if row.InterfaceLuid != 0 {
return "luid" + strconv.FormatUint(row.InterfaceLuid, 10)
}
return ""
}
func windowsNetConnectionTypes(types string) (wantTCP, wantUDP bool, err error) {
normalized := strings.ToLower(strings.TrimSpace(types))
if strings.Contains(normalized, "unix") {
return false, false, ERR_UNSUPPORTED
}
if normalized == "" || strings.Contains(normalized, "all") {
return true, true, nil
}
if strings.Contains(normalized, "tcp") {
wantTCP = true
}
if strings.Contains(normalized, "udp") {
wantUDP = true
}
if !wantTCP && !wantUDP {
return false, false, errors.New("unsupported net connection type")
}
return wantTCP, wantUDP, nil
}
func appendWindowsTCPConnections(result []NetConn, analysePid bool, processCache map[int64]*Process) ([]NetConn, error) {
rows4, err := win32api.GetExtendedTcp4Table(false, win32api.TCP_TABLE_OWNER_PID_ALL)
if err != nil {
return result, err
}
for _, row := range rows4 {
conn := NetConn{
LocalAddr: windowsIPv4FromDWORD(row.LocalAddr),
LocalPort: int(row.LocalPortHost()),
RemoteAddr: windowsIPv4FromDWORD(row.RemoteAddr),
RemotePort: int(row.RemotePortHost()),
Status: windowsTCPState(row.State),
Typed: "tcp",
}
attachWindowsProcess(&conn, row.OwningPid, analysePid, processCache)
result = append(result, conn)
}
rows6, err := win32api.GetExtendedTcp6Table(false, win32api.TCP_TABLE_OWNER_PID_ALL)
if err != nil {
if isOptionalWindowsNetTableError(err) {
return result, nil
}
return result, err
}
for _, row := range rows6 {
conn := NetConn{
LocalAddr: net.IP(row.LocalAddr[:]).String(),
LocalPort: int(row.LocalPortHost()),
RemoteAddr: net.IP(row.RemoteAddr[:]).String(),
RemotePort: int(row.RemotePortHost()),
Status: windowsTCPState(row.State),
Typed: "tcp6",
}
attachWindowsProcess(&conn, row.OwningPid, analysePid, processCache)
result = append(result, conn)
}
return result, nil
}
func appendWindowsUDPConnections(result []NetConn, analysePid bool, processCache map[int64]*Process) ([]NetConn, error) {
rows4, err := win32api.GetExtendedUdp4Table(false, win32api.UDP_TABLE_OWNER_PID)
if err != nil {
return result, err
}
for _, row := range rows4 {
conn := NetConn{
LocalAddr: windowsIPv4FromDWORD(row.LocalAddr),
LocalPort: int(row.LocalPortHost()),
Typed: "udp",
}
attachWindowsProcess(&conn, row.OwningPid, analysePid, processCache)
result = append(result, conn)
}
rows6, err := win32api.GetExtendedUdp6Table(false, win32api.UDP_TABLE_OWNER_PID)
if err != nil {
if isOptionalWindowsNetTableError(err) {
return result, nil
}
return result, err
}
for _, row := range rows6 {
conn := NetConn{
LocalAddr: net.IP(row.LocalAddr[:]).String(),
LocalPort: int(row.LocalPortHost()),
Typed: "udp6",
}
attachWindowsProcess(&conn, row.OwningPid, analysePid, processCache)
result = append(result, conn)
}
return result, nil
}
func attachWindowsProcess(conn *NetConn, pid uint32, analysePid bool, processCache map[int64]*Process) {
if conn == nil || !analysePid {
return
}
conn.Pid = int64(pid)
if conn.Pid <= 0 {
return
}
if proc, ok := processCache[conn.Pid]; ok {
conn.Process = proc
return
}
proc, err := FindProcessByPid(conn.Pid)
if err != nil {
processCache[conn.Pid] = nil
return
}
processCache[conn.Pid] = &proc
conn.Process = &proc
}
func windowsIPv4FromDWORD(addr uint32) string {
return net.IPv4(byte(addr), byte(addr>>8), byte(addr>>16), byte(addr>>24)).String()
}
func windowsTCPState(state win32api.MIB_TCP_STATE) string {
switch state {
case win32api.MIB_TCP_STATE_CLOSED:
return TCP_STATE[TCP_CLOSE]
case win32api.MIB_TCP_STATE_LISTEN:
return TCP_STATE[TCP_LISTEN]
case win32api.MIB_TCP_STATE_SYN_SENT:
return TCP_STATE[TCP_SYN_SENT]
case win32api.MIB_TCP_STATE_SYN_RCVD:
return TCP_STATE[TCP_SYN_RECV]
case win32api.MIB_TCP_STATE_ESTAB:
return TCP_STATE[TCP_ESTABLISHED]
case win32api.MIB_TCP_STATE_FIN_WAIT1:
return TCP_STATE[TCP_FIN_WAIT1]
case win32api.MIB_TCP_STATE_FIN_WAIT2:
return TCP_STATE[TCP_FIN_WAIT2]
case win32api.MIB_TCP_STATE_CLOSE_WAIT:
return TCP_STATE[TCP_CLOSE_WAIT]
case win32api.MIB_TCP_STATE_CLOSING:
return TCP_STATE[TCP_CLOSING]
case win32api.MIB_TCP_STATE_LAST_ACK:
return TCP_STATE[TCP_LAST_ACK]
case win32api.MIB_TCP_STATE_TIME_WAIT:
return TCP_STATE[TCP_TIME_WAIT]
case win32api.MIB_TCP_STATE_DELETE_TCB:
return "TCP_DELETE_TCB"
default:
return TCP_STATE[TCP_UNKNOWN]
}
}
func isOptionalWindowsNetTableError(err error) bool {
if errno, ok := err.(syscall.Errno); ok {
return errno == windowsErrorNotSupported
}
return false
}

121
network_windows_test.go Normal file
View File

@ -0,0 +1,121 @@
//go:build windows
// +build windows
package staros
import (
"errors"
"syscall"
"testing"
"b612.me/win32api"
)
func TestWindowsNetConnectionTypes(t *testing.T) {
tcp, udp, err := windowsNetConnectionTypes("")
if err != nil {
t.Fatal(err)
}
if !tcp || !udp {
t.Fatalf("empty types should request tcp and udp, got tcp=%v udp=%v", tcp, udp)
}
tcp, udp, err = windowsNetConnectionTypes("all")
if err != nil {
t.Fatal(err)
}
if !tcp || !udp {
t.Fatalf("all types should request tcp and udp, got tcp=%v udp=%v", tcp, udp)
}
tcp, udp, err = windowsNetConnectionTypes("tcp")
if err != nil {
t.Fatal(err)
}
if !tcp || udp {
t.Fatalf("tcp types mismatch: tcp=%v udp=%v", tcp, udp)
}
tcp, udp, err = windowsNetConnectionTypes("TCP,UDP")
if err != nil {
t.Fatal(err)
}
if !tcp || !udp {
t.Fatalf("mixed tcp/udp types mismatch: tcp=%v udp=%v", tcp, udp)
}
tcp, udp, err = windowsNetConnectionTypes("udp")
if err != nil {
t.Fatal(err)
}
if tcp || !udp {
t.Fatalf("udp types mismatch: tcp=%v udp=%v", tcp, udp)
}
if _, _, err = windowsNetConnectionTypes("unix"); !errors.Is(err, ERR_UNSUPPORTED) {
t.Fatalf("unix should be unsupported on windows, got %v", err)
}
if _, _, err = windowsNetConnectionTypes("tcp,unix"); !errors.Is(err, ERR_UNSUPPORTED) {
t.Fatalf("mixed unix request should be unsupported on windows, got %v", err)
}
if _, _, err = windowsNetConnectionTypes("all,unix"); !errors.Is(err, ERR_UNSUPPORTED) {
t.Fatalf("all plus unix request should be unsupported on windows, got %v", err)
}
if _, _, err = windowsNetConnectionTypes("raw"); err == nil {
t.Fatal("unknown type should return error")
}
}
func TestWindowsIPv4FromDWORD(t *testing.T) {
if got := windowsIPv4FromDWORD(0x0100007f); got != "127.0.0.1" {
t.Fatalf("unexpected localhost conversion: %s", got)
}
}
func TestWindowsTCPState(t *testing.T) {
cases := map[win32api.MIB_TCP_STATE]string{
win32api.MIB_TCP_STATE_ESTAB: TCP_STATE[TCP_ESTABLISHED],
win32api.MIB_TCP_STATE_LISTEN: TCP_STATE[TCP_LISTEN],
win32api.MIB_TCP_STATE_SYN_SENT: TCP_STATE[TCP_SYN_SENT],
win32api.MIB_TCP_STATE_SYN_RCVD: TCP_STATE[TCP_SYN_RECV],
win32api.MIB_TCP_STATE_FIN_WAIT1: TCP_STATE[TCP_FIN_WAIT1],
win32api.MIB_TCP_STATE_FIN_WAIT2: TCP_STATE[TCP_FIN_WAIT2],
win32api.MIB_TCP_STATE_TIME_WAIT: TCP_STATE[TCP_TIME_WAIT],
win32api.MIB_TCP_STATE_CLOSED: TCP_STATE[TCP_CLOSE],
win32api.MIB_TCP_STATE_CLOSE_WAIT: TCP_STATE[TCP_CLOSE_WAIT],
win32api.MIB_TCP_STATE_LAST_ACK: TCP_STATE[TCP_LAST_ACK],
win32api.MIB_TCP_STATE_CLOSING: TCP_STATE[TCP_CLOSING],
}
for state, want := range cases {
if got := windowsTCPState(state); got != want {
t.Fatalf("state %d mismatch: got=%s want=%s", state, got, want)
}
}
if got := windowsTCPState(win32api.MIB_TCP_STATE(0)); got != TCP_STATE[TCP_UNKNOWN] {
t.Fatalf("unknown state mismatch: %s", got)
}
}
func TestIsOptionalWindowsNetTableError(t *testing.T) {
if !isOptionalWindowsNetTableError(windowsErrorNotSupported) {
t.Fatal("ERROR_NOT_SUPPORTED should be optional")
}
if isOptionalWindowsNetTableError(syscall.EINVAL) {
t.Fatal("EINVAL should not be optional")
}
}
func TestAttachWindowsProcess(t *testing.T) {
conn := NetConn{}
cache := map[int64]*Process{}
attachWindowsProcess(&conn, 123, false, cache)
if conn.Pid != 0 || conn.Process != nil {
t.Fatalf("analysePid=false should not populate process fields: %#v", conn)
}
conn = NetConn{}
attachWindowsProcess(&conn, 0, true, cache)
if conn.Pid != 0 || conn.Process != nil {
t.Fatalf("pid 0 should not populate process fields: %#v", conn)
}
}

46
os.go
View File

@ -1,19 +1,34 @@
package staros package staros
import ( import (
"fmt"
"os/user" "os/user"
"strconv" "strconv"
) )
func parseUint32Identity(kind, raw string) (uint32, error) {
value, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
return 0, fmt.Errorf("parse %s %q: %w", kind, raw, err)
}
return uint32(value), nil
}
// GetUidGid // GetUidGid
func GetUidGid(uname string) (uint32, uint32, string, error) { func GetUidGid(uname string) (uint32, uint32, string, error) {
usr, err := user.Lookup(uname) usr, err := user.Lookup(uname)
if err != nil { if err != nil {
return 0, 0, "", err return 0, 0, "", err
} }
uidInt, _ := strconv.Atoi(usr.Uid) uid, err := parseUint32Identity("uid", usr.Uid)
gidInt, _ := strconv.Atoi(usr.Gid) if err != nil {
return uint32(uidInt), uint32(gidInt), usr.HomeDir, nil return 0, 0, "", err
}
gid, err := parseUint32Identity("gid", usr.Gid)
if err != nil {
return 0, 0, "", err
}
return uid, gid, usr.HomeDir, nil
} }
// GetUid // GetUid
@ -22,26 +37,23 @@ func GetUid(uname string) (uint32, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
uidInt, _ := strconv.Atoi(usr.Uid) return parseUint32Identity("uid", usr.Uid)
return uint32(uidInt), nil
} }
// GetGid // GetGid
func GetGid(uname string) (uint32, error) { func GetGid(uname string) (uint32, error) {
usr, err := user.LookupGroup(uname)
if err != nil {
return 0, err
}
gidInt, _ := strconv.Atoi(usr.Gid)
return uint32(gidInt), nil
}
// GetGidByName
func GetGidByName(uname string) (uint32, error) {
usr, err := user.Lookup(uname) usr, err := user.Lookup(uname)
if err != nil { if err != nil {
return 0, err return 0, err
} }
uidInt, _ := strconv.Atoi(usr.Gid) return parseUint32Identity("gid", usr.Gid)
return uint32(uidInt), nil }
// GetGidByName
func GetGidByName(uname string) (uint32, error) {
usr, err := user.LookupGroup(uname)
if err != nil {
return 0, err
}
return parseUint32Identity("gid", usr.Gid)
} }

69
os_darwin.go Normal file
View File

@ -0,0 +1,69 @@
//go:build darwin
// +build darwin
package staros
import (
"os/user"
"strconv"
"syscall"
"time"
)
// StartTime is not implemented on Darwin yet.
func StartTime() time.Time {
return time.Time{}
}
// IsRoot 当前是否是管理员用户
func IsRoot() bool {
uid, err := user.Current()
return err == nil && uid.Uid == "0"
}
func Whoami() (uid, gid int, uname, gname, home string, err error) {
var me *user.User
var group *user.Group
me, err = user.Current()
if err != nil {
return
}
uid, _ = strconv.Atoi(me.Uid)
gid, _ = strconv.Atoi(me.Gid)
home = me.HomeDir
uname = me.Username
group, err = user.LookupGroupId(me.Gid)
if err != nil {
return
}
gname = group.Name
return
}
func CpuUsageByPid(pid int, sleep time.Duration) float64 {
return 0
}
func CpuUsage(sleep time.Duration) float64 {
return 0
}
func DiskUsage(path string) (disk DiskStatus) {
disk, _ = DiskUsageE(path)
return
}
func DiskUsageE(path string) (disk DiskStatus, err error) {
if path == "" {
path = "."
}
fs := syscall.Statfs_t{}
if err = syscall.Statfs(path, &fs); err != nil {
return
}
disk.All = fs.Blocks * uint64(fs.Bsize)
disk.Free = fs.Bfree * uint64(fs.Bsize)
disk.Available = fs.Bavail * uint64(fs.Bsize)
disk.Used = disk.All - disk.Free
return
}

View File

@ -1,11 +1,110 @@
package staros package staros
import ( import (
"fmt" "errors"
"os/user"
"strconv"
"testing" "testing"
"time"
) )
func Test_Disk(t *testing.T) { func Test_Disk(t *testing.T) {
a := DiskUsage("c:") disk, err := DiskUsageE(".")
fmt.Println(a) if err != nil {
t.Fatal(err)
}
if disk.All == 0 {
t.Fatal("expected non-zero total disk size")
}
if disk.Used+disk.Free != disk.All {
t.Fatalf("expected used + free == all, got used=%d free=%d all=%d", disk.Used, disk.Free, disk.All)
}
}
func TestCpuUsageDoesNotPanic(t *testing.T) {
_ = CpuUsage(time.Millisecond)
}
func TestWhoamiGID(t *testing.T) {
_, gid, _, _, _, err := Whoami()
if errors.Is(err, ERR_UNSUPPORTED) {
t.Skip(err)
}
if err != nil {
t.Fatal(err)
}
current, err := user.Current()
if err != nil {
t.Fatal(err)
}
expected, err := strconv.Atoi(current.Gid)
if err != nil {
t.Fatal(err)
}
if gid != expected {
t.Fatalf("expected gid %d, got %d", expected, gid)
}
}
func TestIdentityLookupFunctions(t *testing.T) {
current, err := user.Current()
if err != nil {
t.Skipf("user.Current unavailable: %v", err)
}
wantUID, uidErr := strconv.ParseUint(current.Uid, 10, 32)
wantGID, gidErr := strconv.ParseUint(current.Gid, 10, 32)
uid, gid, home, err := GetUidGid(current.Username)
if uidErr == nil && gidErr == nil {
if err != nil {
t.Fatalf("GetUidGid failed: %v", err)
}
if uid != uint32(wantUID) || gid != uint32(wantGID) || home != current.HomeDir {
t.Fatalf("GetUidGid mismatch: uid=%d gid=%d home=%q", uid, gid, home)
}
} else if err == nil {
t.Fatalf("GetUidGid should reject non-numeric ids: uid=%q gid=%q", current.Uid, current.Gid)
}
uid, err = GetUid(current.Username)
if uidErr == nil {
if err != nil {
t.Fatalf("GetUid failed: %v", err)
}
if uid != uint32(wantUID) {
t.Fatalf("GetUid mismatch: got=%d want=%d", uid, wantUID)
}
} else if err == nil {
t.Fatalf("GetUid should reject non-numeric uid %q", current.Uid)
}
gid, err = GetGid(current.Username)
if gidErr == nil {
if err != nil {
t.Fatalf("GetGid failed: %v", err)
}
if gid != uint32(wantGID) {
t.Fatalf("GetGid mismatch: got=%d want=%d", gid, wantGID)
}
} else if err == nil {
t.Fatalf("GetGid should reject non-numeric gid %q", current.Gid)
}
group, err := user.LookupGroupId(current.Gid)
if err != nil {
t.Skipf("user.LookupGroupId unavailable: %v", err)
}
groupID, groupErr := strconv.ParseUint(group.Gid, 10, 32)
gotGroupID, err := GetGidByName(group.Name)
if groupErr == nil {
if err != nil {
t.Fatalf("GetGidByName failed: %v", err)
}
if gotGroupID != uint32(groupID) {
t.Fatalf("GetGidByName mismatch: got=%d want=%d", gotGroupID, groupID)
}
} else if err == nil {
t.Fatalf("GetGidByName should reject non-numeric gid %q", group.Gid)
}
} }

View File

@ -1,19 +1,27 @@
// +build linux darwin unix //go:build linux
// +build linux
package staros package staros
import ( import (
"bytes" "bytes"
"fmt" "encoding/binary"
"errors"
"io/ioutil" "io/ioutil"
"os"
"os/user" "os/user"
"strconv" "strconv"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
"unsafe"
) )
var clockTicks = 100 // default value var (
clockTicksOnce sync.Once
clockTicksValue uint64 = 100
)
// StartTime 开机时间 // StartTime 开机时间
func StartTime() time.Time { func StartTime() time.Time {
@ -25,11 +33,8 @@ func StartTime() time.Time {
// IsRoot 当前是否是管理员用户 // IsRoot 当前是否是管理员用户
func IsRoot() bool { func IsRoot() bool {
uid, _ := user.Current() uid, err := user.Current()
if uid.Uid == "0" { return err == nil && uid != nil && uid.Uid == "0"
return true
}
return false
} }
func Whoami() (uid, gid int, uname, gname, home string, err error) { func Whoami() (uid, gid int, uname, gname, home string, err error) {
@ -40,7 +45,7 @@ func Whoami() (uid, gid int, uname, gname, home string, err error) {
return return
} }
uid, _ = strconv.Atoi(me.Uid) uid, _ = strconv.Atoi(me.Uid)
gid, _ = strconv.Atoi(me.Uid) gid, _ = strconv.Atoi(me.Gid)
home = me.HomeDir home = me.HomeDir
uname = me.Username uname = me.Username
gup, err = user.LookupGroupId(me.Gid) gup, err = user.LookupGroupId(me.Gid)
@ -51,6 +56,79 @@ func Whoami() (uid, gid int, uname, gname, home string, err error) {
return return
} }
func clockTicks() uint64 {
clockTicksOnce.Do(initClockTicks)
if clockTicksValue == 0 {
return 100
}
return clockTicksValue
}
func initClockTicks() {
ticks, err := readClockTicksFromAuxv()
if err != nil || ticks == 0 {
return
}
clockTicksValue = ticks
}
func readClockTicksFromAuxv() (uint64, error) {
data, err := os.ReadFile("/proc/self/auxv")
if err != nil {
return 0, err
}
wordSize := int(unsafe.Sizeof(uintptr(0)))
if wordSize != 4 && wordSize != 8 {
return 0, errors.New("unsupported pointer size")
}
order := nativeEndian()
entrySize := wordSize * 2
for offset := 0; offset+entrySize <= len(data); offset += entrySize {
key := readAuxvWord(data[offset:offset+wordSize], order)
val := readAuxvWord(data[offset+wordSize:offset+entrySize], order)
if key == 0 {
break
}
if key == 17 {
return val, nil
}
}
return 0, errors.New("AT_CLKTCK not found")
}
func readAuxvWord(data []byte, order binary.ByteOrder) uint64 {
if len(data) == 4 {
return uint64(order.Uint32(data))
}
return order.Uint64(data)
}
func nativeEndian() binary.ByteOrder {
var n uint16 = 1
if *(*byte)(unsafe.Pointer(&n)) == 1 {
return binary.LittleEndian
}
return binary.BigEndian
}
func cpuUsageOverDuration(delta float64, sleep time.Duration) float64 {
if delta < 0 || sleep <= 0 {
return 0
}
seconds := sleep.Seconds()
if seconds <= 0 {
return 0
}
return delta / seconds * 100
}
func cpuUsagePercent(busyTicks, totalTicks float64) float64 {
if busyTicks < 0 || totalTicks <= 0 {
return 0
}
return 100 * busyTicks / totalTicks
}
func getCPUSample() (idle, total uint64) { func getCPUSample() (idle, total uint64) {
contents, err := ioutil.ReadFile("/proc/stat") contents, err := ioutil.ReadFile("/proc/stat")
if err != nil { if err != nil {
@ -59,12 +137,15 @@ func getCPUSample() (idle, total uint64) {
lines := strings.Split(string(contents), "\n") lines := strings.Split(string(contents), "\n")
for _, line := range lines { for _, line := range lines {
fields := strings.Fields(line) fields := strings.Fields(line)
if len(fields) == 0 {
continue
}
if fields[0] == "cpu" { if fields[0] == "cpu" {
numFields := len(fields) numFields := len(fields)
for i := 1; i < numFields; i++ { for i := 1; i < numFields; i++ {
val, err := strconv.ParseUint(fields[i], 10, 64) val, err := strconv.ParseUint(fields[i], 10, 64)
if err != nil { if err != nil {
fmt.Println("Error: ", i, fields[i], err) continue
} }
total += val // tally up all the numbers to get total ticks total += val // tally up all the numbers to get total ticks
if i == 4 || i == 5 { // idle is the 5th field in the cpu line if i == 4 || i == 5 { // idle is the 5th field in the cpu line
@ -117,31 +198,45 @@ func getCPUSampleByPid(pid int) float64 {
} else { } else {
iotime = 0 // e.g. SmartOS containers iotime = 0 // e.g. SmartOS containers
} }
return utime/float64(clockTicks) + stime/float64(clockTicks) + iotime/float64(clockTicks) ticks := float64(clockTicks())
return utime/ticks + stime/ticks + iotime/ticks
} }
func CpuUsageByPid(pid int, sleep time.Duration) float64 { func CpuUsageByPid(pid int, sleep time.Duration) float64 {
if sleep <= 0 {
return 0
}
total1 := getCPUSampleByPid(pid) total1 := getCPUSampleByPid(pid)
time.Sleep(sleep) time.Sleep(sleep)
total2 := getCPUSampleByPid(pid) total2 := getCPUSampleByPid(pid)
return (total2 - total1) / sleep.Seconds() * 100 return cpuUsageOverDuration(total2-total1, sleep)
} }
// CpuUsage 获取CPU使用量 // CpuUsage 获取CPU使用量
func CpuUsage(sleep time.Duration) float64 { func CpuUsage(sleep time.Duration) float64 {
if sleep <= 0 {
return 0
}
idle0, total0 := getCPUSample() idle0, total0 := getCPUSample()
time.Sleep(sleep) time.Sleep(sleep)
idle1, total1 := getCPUSample() idle1, total1 := getCPUSample()
idleTicks := float64(idle1 - idle0) idleTicks := float64(idle1 - idle0)
totalTicks := float64(total1 - total0) totalTicks := float64(total1 - total0)
cpuUsage := 100 * (totalTicks - idleTicks) / totalTicks cpuUsage := cpuUsagePercent(totalTicks-idleTicks, totalTicks)
return cpuUsage return cpuUsage
//fmt.Printf("CPU usage is %f%% [busy: %f, total: %f]\n", cpuUsage, totalTicks-idleTicks, totalTicks) //fmt.Printf("CPU usage is %f%% [busy: %f, total: %f]\n", cpuUsage, totalTicks-idleTicks, totalTicks)
} }
func DiskUsage(path string) (disk DiskStatus) { func DiskUsage(path string) (disk DiskStatus) {
disk, _ = DiskUsageE(path)
return
}
func DiskUsageE(path string) (disk DiskStatus, err error) {
if path == "" {
path = "."
}
fs := syscall.Statfs_t{} fs := syscall.Statfs_t{}
err := syscall.Statfs(path, &fs) if err = syscall.Statfs(path, &fs); err != nil {
if err != nil {
return return
} }
disk.All = fs.Blocks * uint64(fs.Bsize) disk.All = fs.Blocks * uint64(fs.Bsize)

36
os_unix_test.go Normal file
View File

@ -0,0 +1,36 @@
//go:build linux
// +build linux
package staros
import (
"testing"
"time"
)
func TestCPUUsageOverDurationGuardsZeroOrNegativeWindow(t *testing.T) {
if got := cpuUsageOverDuration(10, 0); got != 0 {
t.Fatalf("expected zero-window cpu usage to clamp to 0, got %v", got)
}
if got := cpuUsageOverDuration(10, -time.Millisecond); got != 0 {
t.Fatalf("expected negative-window cpu usage to clamp to 0, got %v", got)
}
if got := cpuUsageOverDuration(-1, time.Second); got != 0 {
t.Fatalf("expected negative delta cpu usage to clamp to 0, got %v", got)
}
if got := cpuUsageOverDuration(0.5, time.Second); got != 50 {
t.Fatalf("expected normal cpu usage calculation, got %v", got)
}
}
func TestCPUUsagePercentGuardsInvalidSamples(t *testing.T) {
if got := cpuUsagePercent(1, 0); got != 0 {
t.Fatalf("expected zero total ticks to clamp to 0, got %v", got)
}
if got := cpuUsagePercent(-1, 4); got != 0 {
t.Fatalf("expected negative busy ticks to clamp to 0, got %v", got)
}
if got := cpuUsagePercent(1, 4); got != 25 {
t.Fatalf("expected normal cpu percent calculation, got %v", got)
}
}

View File

@ -1,9 +1,9 @@
//go:build windows
// +build windows // +build windows
package staros package staros
import ( import (
"log"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
@ -27,27 +27,33 @@ func IsRoot() bool {
return wincmd.Isas() return wincmd.Isas()
} }
func DiskUsage(path string) (disk DiskStatus) { func DiskUsage(path string) (disk DiskStatus) {
kernel32, err := syscall.LoadLibrary("Kernel32.dll") disk, _ = DiskUsageE(path)
if err != nil { return
log.Panic(err) }
}
defer syscall.FreeLibrary(kernel32)
GetDiskFreeSpaceEx, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GetDiskFreeSpaceExW")
if err != nil { func DiskUsageE(path string) (disk DiskStatus, err error) {
log.Panic(err) if path == "" {
path = "."
} }
lpFreeBytesAvailable := int64(0) lpFreeBytesAvailable := int64(0)
lpTotalNumberOfBytes := int64(0) lpTotalNumberOfBytes := int64(0)
lpTotalNumberOfFreeBytes := int64(0) lpTotalNumberOfFreeBytes := int64(0)
syscall.Syscall6(uintptr(GetDiskFreeSpaceEx), 4,
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr("C:"))), path16, err := syscall.UTF16PtrFromString(path)
if err != nil {
return
}
r1, _, callErr := syscall.NewLazyDLL("kernel32.dll").NewProc("GetDiskFreeSpaceExW").Call(
uintptr(unsafe.Pointer(path16)),
uintptr(unsafe.Pointer(&lpFreeBytesAvailable)), uintptr(unsafe.Pointer(&lpFreeBytesAvailable)),
uintptr(unsafe.Pointer(&lpTotalNumberOfBytes)), uintptr(unsafe.Pointer(&lpTotalNumberOfBytes)),
uintptr(unsafe.Pointer(&lpTotalNumberOfFreeBytes)), 0, 0) uintptr(unsafe.Pointer(&lpTotalNumberOfFreeBytes)),
)
if r1 == 0 {
err = callErr
return
}
disk.Free = uint64(lpTotalNumberOfFreeBytes) disk.Free = uint64(lpTotalNumberOfFreeBytes)
disk.Used = uint64(lpTotalNumberOfBytes - lpTotalNumberOfFreeBytes) disk.Used = uint64(lpTotalNumberOfBytes - lpTotalNumberOfFreeBytes)
disk.All = uint64(lpTotalNumberOfBytes) disk.All = uint64(lpTotalNumberOfBytes)

File diff suppressed because it is too large Load Diff

73
process_darwin.go Normal file
View File

@ -0,0 +1,73 @@
//go:build darwin
// +build darwin
package staros
import "sync/atomic"
func FindProcessByName(name string) (datas []Process, err error) {
return nil, ERR_UNSUPPORTED
}
func FindProcess(compare func(Process) bool) (datas []Process, err error) {
return nil, ERR_UNSUPPORTED
}
func FindProcessByPid(pid int64) (datas Process, err error) {
return datas, ERR_UNSUPPORTED
}
func Daemon(path string, args ...string) (int, error) {
return -1, ERR_UNSUPPORTED
}
func DaemonWithUser(uid, gid uint32, groups []uint32, path string, args ...string) (int, error) {
return -1, ERR_UNSUPPORTED
}
func (starcli *StarCmd) SetRunUser(uid, gid uint32, groups []uint32) {
_ = starcli.SetRunUserE(uid, gid, groups)
}
func (starcli *StarCmd) SetRunUserE(uid, gid uint32, groups []uint32) error {
if starcli == nil || starcli.CMD == nil {
return errNilCommand
}
if atomic.LoadInt32(&starcli.started) != 0 {
return errCommandAlreadyStarted
}
return ERR_UNSUPPORTED
}
func (starcli *StarCmd) Release() error {
return starcli.ReleaseE()
}
func (starcli *StarCmd) Detach() error {
return starcli.DetachE()
}
func (starcli *StarCmd) ReleaseE() error {
if starcli == nil || starcli.CMD == nil {
return errNilCommand
}
return ERR_UNSUPPORTED
}
func (starcli *StarCmd) DetachE() error {
if starcli == nil || starcli.CMD == nil {
return errNilCommand
}
return ERR_UNSUPPORTED
}
func (starcli *StarCmd) SetKeepCaps() error {
if starcli == nil || starcli.CMD == nil {
return errNilCommand
}
return ERR_UNSUPPORTED
}
func SetKeepCaps() error {
return ERR_UNSUPPORTED
}

110
process_linux_test.go Normal file
View File

@ -0,0 +1,110 @@
//go:build linux
// +build linux
package staros
import (
"errors"
"reflect"
"syscall"
"testing"
)
func TestStarCmdSetKeepCapsConfiguresAmbientCaps(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
original := loadCurrentKeepCaps
loadCurrentKeepCaps = func() ([]uintptr, error) {
return []uintptr{7, 1, 7}, nil
}
t.Cleanup(func() {
loadCurrentKeepCaps = original
})
cmd.CMD.SysProcAttr = &syscall.SysProcAttr{
AmbientCaps: []uintptr{9, 1},
}
if err := cmd.SetKeepCaps(); err != nil {
t.Fatal(err)
}
want := []uintptr{1, 7, 9}
if got := cmd.CMD.SysProcAttr.AmbientCaps; !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected ambient caps: got=%v want=%v", got, want)
}
}
func TestStarCmdSetKeepCapsPropagatesCapabilityReadError(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
wantErr := errors.New("capget failed")
original := loadCurrentKeepCaps
loadCurrentKeepCaps = func() ([]uintptr, error) {
return nil, wantErr
}
t.Cleanup(func() {
loadCurrentKeepCaps = original
})
if err := cmd.SetKeepCaps(); !errors.Is(err, wantErr) {
t.Fatalf("expected keepcaps read error, got %v", err)
}
}
func TestStarCmdSetRunUserPreservesExistingSysProcAttr(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
original := loadCurrentKeepCaps
loadCurrentKeepCaps = func() ([]uintptr, error) {
return []uintptr{7, 1, 7}, nil
}
t.Cleanup(func() {
loadCurrentKeepCaps = original
})
cmd.CMD.SysProcAttr = &syscall.SysProcAttr{
Pdeathsig: syscall.SIGTERM,
AmbientCaps: []uintptr{9},
}
if err := cmd.SetKeepCaps(); err != nil {
t.Fatal(err)
}
groups := []uint32{3, 4}
if err := cmd.SetRunUserE(1, 2, groups); err != nil {
t.Fatal(err)
}
groups[0] = 99
if got, want := cmd.CMD.SysProcAttr.AmbientCaps, []uintptr{1, 7, 9}; !reflect.DeepEqual(got, want) {
t.Fatalf("ambient caps lost after SetRunUserE: got=%v want=%v", got, want)
}
if got := cmd.CMD.SysProcAttr.Pdeathsig; got != syscall.SIGTERM {
t.Fatalf("expected Pdeathsig to be preserved, got %v", got)
}
if !cmd.CMD.SysProcAttr.Setsid {
t.Fatal("expected Setsid to be enabled")
}
cred := cmd.CMD.SysProcAttr.Credential
if cred == nil {
t.Fatal("expected credential to be configured")
}
if cred.Uid != 1 || cred.Gid != 2 {
t.Fatalf("unexpected credential ids: uid=%d gid=%d", cred.Uid, cred.Gid)
}
if got, want := cred.Groups, []uint32{3, 4}; !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected credential groups: got=%v want=%v", got, want)
}
}

View File

@ -1,27 +1,813 @@
package staros package staros
import ( import (
"bytes"
"context" "context"
"fmt" "encoding/base64"
"encoding/binary"
"errors"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"strings"
"testing" "testing"
"time" "time"
"unicode/utf16"
) )
func Test_Process(t *testing.T) { func testCommandArgs(script string) (string, []string) {
fmt.Println(FindProcessByPid(16652)) if runtime.GOOS == "windows" {
} return "cmd.exe", []string{"/c", script}
}
func Test_StarCmd(t *testing.T) { return "sh", []string{"-c", script}
ctx, _ := context.WithTimeout(context.Background(), time.Second*5) }
cmd, _ := CommandContext(ctx, "cmd.exe", "/c", "ping -t 127.0.0.1")
cmd.Start() func testWindowsPowerShellArgs(script string) (string, []string) {
for cmd.IsRunning() { utf16Script := utf16.Encode([]rune(script))
fmt.Print(cmd.NowLineOutput()) encoded := make([]byte, len(utf16Script)*2)
time.Sleep(time.Millisecond * 50) for i, r := range utf16Script {
binary.LittleEndian.PutUint16(encoded[i*2:], uint16(r))
}
return "powershell.exe", []string{"-NoProfile", "-EncodedCommand", base64.StdEncoding.EncodeToString(encoded)}
}
type closeTrackingWriteCloser struct {
closed bool
}
func (w *closeTrackingWriteCloser) Write(data []byte) (int, error) {
return len(data), nil
}
func (w *closeTrackingWriteCloser) Close() error {
w.closed = true
return nil
}
func TestStarCmdCapturesOutputAndExitCode(t *testing.T) {
script := "printf 'hello'; printf 'err' 1>&2"
command, args := testCommandArgs(script)
if runtime.GOOS == "windows" {
command, args = testWindowsPowerShellArgs("[Console]::Out.Write('hello'); [Console]::Error.Write('err')")
}
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
<-cmd.Stoped()
out, outErr := cmd.AllOutPut()
if out != "hello" {
t.Fatalf("expected stdout %q, got %q", "hello", out)
}
if outErr == nil || outErr.Error() != "err" {
t.Fatalf("expected stderr error %q, got %v", "err", outErr)
}
if got := cmd.ExitCode(); got != 0 {
t.Fatalf("expected exit code 0, got %d", got)
}
}
func TestStarCmdWaitReturnsProcessError(t *testing.T) {
command, args := testCommandArgs("exit 7")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Wait(); !errors.Is(err, errCommandProcessNotStarted) {
t.Fatalf("expected errCommandProcessNotStarted before start, got %v", err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
if err := cmd.Wait(); err == nil {
t.Fatal("expected wait error for non-zero exit")
}
if got := cmd.ExitCode(); got != 7 {
t.Fatalf("expected exit code 7, got %d", got)
}
if err := cmd.Wait(); err == nil {
t.Fatal("expected repeated Wait to keep final process error")
}
}
func TestStarCmdWaitTimeoutAndContext(t *testing.T) {
command, args := testCommandArgs("sleep 1")
if runtime.GOOS == "windows" {
command, args = testCommandArgs("ping -n 2 127.0.0.1 >nul")
}
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
if err := cmd.WaitTimeout(10 * time.Millisecond); !errors.Is(err, ERR_TIMEOUT) {
t.Fatalf("expected ERR_TIMEOUT, got %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
if err := cmd.WaitContext(ctx); !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected context deadline, got %v", err)
}
if err := cmd.WaitTimeout(3 * time.Second); err != nil {
t.Fatalf("expected command to finish, got %v", err)
}
}
func TestStarCmdWaitReturnsResultAfterProcessDone(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
if err := cmd.Wait(); err != nil {
t.Fatal(err)
}
if err := cmd.WaitTimeout(0); err != nil {
t.Fatalf("expected finished command to beat zero timeout, got %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := cmd.WaitContext(ctx); err != nil {
t.Fatalf("expected finished command to beat canceled context, got %v", err)
}
}
func TestStarCmdWaitContextFinishedCommandWinsOverCanceledContext(t *testing.T) {
t.Run("success", func(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
if err := cmd.Wait(); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := cmd.WaitContext(ctx); err != nil {
t.Fatalf("finished successful command should win over canceled context, got %v", err)
}
})
t.Run("failed", func(t *testing.T) {
command, args := testCommandArgs("exit 7")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
waitErr := cmd.Wait()
if waitErr == nil {
t.Fatal("expected command wait error")
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := cmd.WaitContext(ctx); err == nil || err.Error() != waitErr.Error() {
t.Fatalf("finished failed command should win over canceled context, got %v, want %v", err, waitErr)
}
})
}
func TestStarCmdStoppedAlias(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
select {
case <-cmd.Stopped():
case <-time.After(time.Second):
t.Fatal("Stopped should close after command exits")
}
select {
case <-cmd.Stoped():
case <-time.After(time.Second):
t.Fatal("Stoped compatibility alias should close after command exits")
}
}
func TestStarCmdStopedPublishesFinalExitCode(t *testing.T) {
command, args := testCommandArgs("exit 7")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
<-cmd.Stoped()
if cmd.IsRunning() {
t.Fatal("command should not be running after Stoped closes")
}
if got := cmd.ExitCode(); got != 7 {
t.Fatalf("expected exit code 7, got %d", got)
}
}
func TestStarCmdRejectsRepeatedStart(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
if err := cmd.Start(); !errors.Is(err, errCommandAlreadyStarted) {
t.Fatalf("expected errCommandAlreadyStarted, got %v", err)
}
<-cmd.Stoped()
if err := cmd.Start(); !errors.Is(err, errCommandAlreadyStarted) {
t.Fatalf("expected errCommandAlreadyStarted after exit, got %v", err)
}
}
func TestStarCmdStartFailureClosesStoped(t *testing.T) {
cmd, err := Command("__staros_missing_command__")
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err == nil {
t.Fatal("expected start failure")
}
select {
case <-cmd.Stoped():
case <-time.After(time.Second):
t.Fatal("Stoped should close after start failure")
}
if cmd.IsRunning() {
t.Fatal("command should not be running after start failure")
}
if got := cmd.ExitCode(); got != -1 {
t.Fatalf("expected exit code -1 after start failure, got %d", got)
}
}
func TestStarCmdCapturesLargeOutput(t *testing.T) {
expected := strings.Repeat("x", 256*1024)
script := "awk 'BEGIN{for(i=0;i<262144;i++) printf \"x\"}'"
command, args := testCommandArgs(script)
if runtime.GOOS == "windows" {
command, args = testWindowsPowerShellArgs("[Console]::Out.Write(('x' * 262144))")
}
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
<-cmd.Stoped()
out, err := cmd.AllOutPut()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal([]byte(out), []byte(expected)) {
t.Fatalf("expected %d stdout bytes, got %d", len(expected), len(out))
}
}
func TestStarCmdStreamsOutput(t *testing.T) {
script := "printf 'out'; printf 'err' 1>&2"
command, args := testCommandArgs(script)
if runtime.GOOS == "windows" {
command, args = testWindowsPowerShellArgs("[Console]::Out.Write('out'); [Console]::Error.Write('err')")
}
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
stdout := cmd.StdoutChan()
stderr := cmd.StderrChan()
output := cmd.OutputChan()
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
var stdoutData, stderrData string
var outputData []StarCmdOutput
for stdout != nil || stderr != nil || output != nil {
select {
case data, ok := <-stdout:
if !ok {
stdout = nil
continue
}
stdoutData += string(data)
case data, ok := <-stderr:
if !ok {
stderr = nil
continue
}
stderrData += string(data)
case data, ok := <-output:
if !ok {
output = nil
continue
}
outputData = append(outputData, data)
case <-time.After(3 * time.Second):
t.Fatal("stream output timed out")
}
}
if stdoutData != "out" {
t.Fatalf("expected streamed stdout %q, got %q", "out", stdoutData)
}
if stderrData != "err" {
t.Fatalf("expected streamed stderr %q, got %q", "err", stderrData)
}
var seenStdout, seenStderr bool
for _, item := range outputData {
switch item.Stream {
case StarCmdOutputStdout:
seenStdout = seenStdout || string(item.Data) == "out"
case StarCmdOutputStderr:
seenStderr = seenStderr || string(item.Data) == "err"
default:
t.Fatalf("unknown output stream %v", item.Stream)
}
}
if !seenStdout || !seenStderr {
t.Fatalf("expected combined output stream to include stdout and stderr, got %#v", outputData)
}
}
func TestStarCmdStreamNilReturnsClosedChannels(t *testing.T) {
var cmd *StarCmd
select {
case _, ok := <-cmd.StdoutChan():
if ok {
t.Fatal("nil stdout stream should be closed")
}
case <-time.After(time.Second):
t.Fatal("nil stdout stream should close immediately")
}
select {
case _, ok := <-cmd.StderrChan():
if ok {
t.Fatal("nil stderr stream should be closed")
}
case <-time.After(time.Second):
t.Fatal("nil stderr stream should close immediately")
}
select {
case _, ok := <-cmd.OutputChan():
if ok {
t.Fatal("nil output stream should be closed")
}
case <-time.After(time.Second):
t.Fatal("nil output stream should close immediately")
}
}
func TestStarCmdRedirectOutputWriterKeepsCapture(t *testing.T) {
script := "printf 'out'; printf 'err' 1>&2"
command, args := testCommandArgs(script)
if runtime.GOOS == "windows" {
command, args = testWindowsPowerShellArgs("[Console]::Out.Write('out'); [Console]::Error.Write('err')")
}
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
var redirected bytes.Buffer
if err := cmd.RedirectOutput(&redirected); err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
<-cmd.Stopped()
if got := redirected.String(); got != "outerr" && got != "errout" {
t.Fatalf("expected redirected stdout/stderr bytes, got %q", got)
}
if out := cmd.AllStdOut(); out != "out" {
t.Fatalf("expected captured stdout %q, got %q", "out", out)
}
if err := cmd.AllStdErr(); err == nil || err.Error() != "err" {
t.Fatalf("expected captured stderr %q, got %v", "err", err)
}
}
func TestStarCmdRedirectFiles(t *testing.T) {
dir := t.TempDir()
stdoutFile := filepath.Join(dir, "stdout.txt")
stderrFile := filepath.Join(dir, "stderr.txt")
script := "printf 'out'; printf 'err' 1>&2"
command, args := testCommandArgs(script)
if runtime.GOOS == "windows" {
command, args = testWindowsPowerShellArgs("[Console]::Out.Write('out'); [Console]::Error.Write('err')")
}
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.RedirectStdoutFile(stdoutFile); err != nil {
t.Fatal(err)
}
if err := cmd.RedirectStderrFile(stderrFile); err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
<-cmd.Stopped()
stdoutData, err := ioutil.ReadFile(stdoutFile)
if err != nil {
t.Fatal(err)
}
stderrData, err := ioutil.ReadFile(stderrFile)
if err != nil {
t.Fatal(err)
}
if string(stdoutData) != "out" {
t.Fatalf("expected stdout file %q, got %q", "out", string(stdoutData))
}
if string(stderrData) != "err" {
t.Fatalf("expected stderr file %q, got %q", "err", string(stderrData))
}
}
func TestStarCmdRedirectStdin(t *testing.T) {
command, args := testCommandArgs("cat")
if runtime.GOOS == "windows" {
command, args = testCommandArgs("more")
}
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.RedirectStdin(strings.NewReader("hello\n")); err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
<-cmd.Stopped()
if out := cmd.AllStdOut(); !strings.Contains(out, "hello") {
t.Fatalf("expected redirected stdin in stdout, got %q", out)
}
if err := cmd.AllStdErr(); err != nil {
t.Fatalf("redirected stdin should not create command error, got %v", err)
}
if err := cmd.WriteCmdE("again"); !errors.Is(err, errCommandStdinUnavailable) {
t.Fatalf("expected errCommandStdinUnavailable after stdin redirect, got %v", err)
}
}
func TestStarCmdRedirectStdinClosesManagedPipe(t *testing.T) {
command, args := testCommandArgs("cat")
if runtime.GOOS == "windows" {
command, args = testCommandArgs("more")
}
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
tracker := &closeTrackingWriteCloser{}
cmd.lock.Lock()
cmd.infile = tracker
cmd.inclosed = false
cmd.lock.Unlock()
if err := cmd.RedirectStdin(strings.NewReader("hello\n")); err != nil {
t.Fatal(err)
}
if !tracker.closed {
t.Fatal("RedirectStdin should close the previously managed stdin pipe")
}
if err := cmd.WriteCmdE("again"); !errors.Is(err, errCommandStdinUnavailable) {
t.Fatalf("expected managed stdin to be unavailable after redirect, got %v", err)
}
}
func TestStarCmdDetachClosesManagedPipe(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
tracker := &closeTrackingWriteCloser{}
cmd.lock.Lock()
original := cmd.infile
cmd.infile = tracker
cmd.inclosed = false
cmd.lock.Unlock()
if original != nil {
_ = original.Close()
}
if err := cmd.DetachE(); errors.Is(err, ERR_UNSUPPORTED) {
t.Skip(err)
} else if err != nil {
t.Fatal(err)
}
if !tracker.closed {
t.Fatal("DetachE should close the managed stdin pipe")
}
if err := cmd.WriteCmdE("again"); !errors.Is(err, errCommandStdinClosed) {
t.Fatalf("expected detached stdin to be closed, got %v", err)
}
}
func TestStarCmdRedirectRejectsInvalidState(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.RedirectStdout(nil); !errors.Is(err, errCommandRedirectNil) {
t.Fatalf("expected errCommandRedirectNil, got %v", err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
if err := cmd.RedirectStdout(&bytes.Buffer{}); !errors.Is(err, errCommandAlreadyStarted) {
t.Fatalf("expected errCommandAlreadyStarted, got %v", err)
}
<-cmd.Stopped()
}
func TestStarCmdCloseStdinLetsCommandExit(t *testing.T) {
script := "cat"
if runtime.GOOS == "windows" {
script = "more"
}
command, args := testCommandArgs(script)
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
if err := cmd.WriteCmdE("hello"); err != nil {
t.Fatal(err)
}
if err := cmd.CloseStdinE(); err != nil {
t.Fatal(err)
}
select {
case <-cmd.Stoped():
case <-time.After(3 * time.Second):
t.Fatal("command should exit after stdin closes")
}
if out := cmd.AllStdOut(); !strings.Contains(out, "hello") {
t.Fatalf("expected echoed stdin, got %q", out)
}
if err := cmd.CloseStdinE(); !errors.Is(err, errCommandStdinClosed) {
t.Fatalf("expected errCommandStdinClosed, got %v", err)
}
if err := cmd.WriteCmdE("again"); !errors.Is(err, errCommandStdinClosed) {
t.Fatalf("expected errCommandStdinClosed after close, got %v", err)
}
}
func TestStarCmdWriteStdinRawDoesNotAppendNewline(t *testing.T) {
script := "cat"
if runtime.GOOS == "windows" {
script = "more"
}
command, args := testCommandArgs(script)
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
if err := cmd.WriteStdinStringE("raw"); err != nil {
t.Fatal(err)
}
if err := cmd.WriteStdinE([]byte("-bytes")); err != nil {
t.Fatal(err)
}
if err := cmd.CloseStdinE(); err != nil {
t.Fatal(err)
}
if err := cmd.WaitTimeout(3 * time.Second); err != nil {
t.Fatal(err)
}
if out := cmd.AllStdOut(); !strings.Contains(out, "raw-bytes") {
t.Fatalf("expected raw stdin without inserted newline, got %q", out)
}
}
func TestStarCmdNilGuards(t *testing.T) {
var cmd *StarCmd
if cmd.IsRunning() {
t.Fatal("nil StarCmd should not be running")
}
if got := cmd.GetPid(); got != -1 {
t.Fatalf("expected nil pid -1, got %d", got)
}
if err := cmd.Release(); !errors.Is(err, errNilCommand) {
t.Fatalf("expected errNilCommand, got %v", err)
}
if err := cmd.SetKeepCaps(); !errors.Is(err, errNilCommand) {
t.Fatalf("expected errNilCommand, got %v", err)
}
if err := cmd.ReleaseE(); !errors.Is(err, errNilCommand) {
t.Fatalf("expected errNilCommand, got %v", err)
}
if err := cmd.DetachE(); !errors.Is(err, errNilCommand) {
t.Fatalf("expected errNilCommand, got %v", err)
}
if err := cmd.SetRunUserE(0, 0, nil); !errors.Is(err, errNilCommand) {
t.Fatalf("expected errNilCommand, got %v", err)
}
if err := cmd.WriteCmdE("noop"); !errors.Is(err, errNilCommand) {
t.Fatalf("expected errNilCommand, got %v", err)
}
if err := cmd.WriteStdinE([]byte("noop")); !errors.Is(err, errNilCommand) {
t.Fatalf("expected errNilCommand, got %v", err)
}
if err := cmd.WriteStdinStringE("noop"); !errors.Is(err, errNilCommand) {
t.Fatalf("expected errNilCommand, got %v", err)
}
if err := cmd.WriteStdinLineE("noop"); !errors.Is(err, errNilCommand) {
t.Fatalf("expected errNilCommand, got %v", err)
}
if err := cmd.Wait(); !errors.Is(err, errNilCommand) {
t.Fatalf("expected errNilCommand, got %v", err)
}
if err := cmd.CloseStdinE(); !errors.Is(err, errNilCommand) {
t.Fatalf("expected errNilCommand, got %v", err)
}
}
func TestStarCmdReleaseUsesStartLifecycle(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.ReleaseE(); errors.Is(err, ERR_UNSUPPORTED) {
t.Skip(err)
} else if err != nil {
t.Fatal(err)
}
<-cmd.Stoped()
if got := cmd.ExitCode(); got != 0 {
t.Fatalf("expected exit code 0, got %d", got)
}
if err := cmd.ReleaseE(); !errors.Is(err, errCommandAlreadyReleased) {
t.Fatalf("expected errCommandAlreadyReleased, got %v", err)
}
}
func TestStarCmdReleaseAfterStartKeepsLifecycle(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
if err := cmd.ReleaseE(); errors.Is(err, ERR_UNSUPPORTED) {
t.Skip(err)
} else if err != nil {
t.Fatal(err)
}
<-cmd.Stoped()
if got := cmd.ExitCode(); got != 0 {
t.Fatalf("expected exit code 0, got %d", got)
}
}
func TestStarCmdDetachRejectsRepeatedDetach(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.DetachE(); errors.Is(err, ERR_UNSUPPORTED) {
t.Skip(err)
} else if err != nil {
t.Fatal(err)
}
if err := cmd.DetachE(); !errors.Is(err, errCommandAlreadyDetached) {
t.Fatalf("expected errCommandAlreadyDetached, got %v", err)
}
if err := cmd.Start(); !errors.Is(err, errCommandDetached) {
t.Fatalf("expected errCommandDetached, got %v", err)
}
}
func TestStarCmdDetachPublishesWaitResult(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.DetachE(); errors.Is(err, ERR_UNSUPPORTED) {
t.Skip(err)
} else if err != nil {
t.Fatal(err)
}
if err := cmd.WaitTimeout(0); err != nil {
t.Fatalf("detached command should publish final wait result, got %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := cmd.WaitContext(ctx); err != nil {
t.Fatalf("detached command should beat canceled context, got %v", err)
}
waitErr := make(chan error, 1)
go func() {
waitErr <- cmd.Wait()
}()
select {
case err := <-waitErr:
if err != nil {
t.Fatalf("detached command wait got %v", err)
}
case <-time.After(time.Second):
t.Fatal("detached command wait did not observe final result")
}
}
func TestStarCmdDetachDoesNotCaptureOutput(t *testing.T) {
script := "printf 'detached'; printf 'err' 1>&2"
if runtime.GOOS == "windows" {
script = "<nul set /p =detached & <nul set /p =err 1>&2"
}
command, args := testCommandArgs(script)
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.DetachE(); errors.Is(err, ERR_UNSUPPORTED) {
t.Skip(err)
} else if err != nil {
t.Fatal(err)
}
<-cmd.Stoped()
if out := cmd.AllStdOut(); out != "" {
t.Fatalf("detached command should not be captured, got stdout %q", out)
}
if err := cmd.AllStdErr(); err != nil {
t.Fatalf("detached command should not capture stderr, got %v", err)
}
}
func TestStarCmdDetachRejectsStartedCommand(t *testing.T) {
command, args := testCommandArgs("exit 0")
cmd, err := Command(command, args...)
if err != nil {
t.Fatal(err)
}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
if err := cmd.DetachE(); errors.Is(err, ERR_UNSUPPORTED) {
t.Skip(err)
} else if !errors.Is(err, errCommandAlreadyStarted) {
t.Fatalf("expected errCommandAlreadyStarted, got %v", err)
}
<-cmd.Stoped()
}
func TestFindProcessByPidCurrentProcess(t *testing.T) {
pid := os.Getpid()
process, err := FindProcessByPid(int64(pid))
if errors.Is(err, ERR_UNSUPPORTED) {
t.Skip(err)
}
if err != nil {
t.Fatal(err)
}
if process.Pid != int64(pid) {
t.Fatalf("expected pid %d, got %d", pid, process.Pid)
}
}
func TestStopedNilReturnsClosedChannel(t *testing.T) {
var cmd *StarCmd
select {
case <-cmd.Stoped():
case <-time.After(time.Second):
t.Fatal("nil Stoped channel should already be closed")
}
select {
case <-cmd.Stopped():
case <-time.After(time.Second):
t.Fatal("nil Stopped channel should already be closed")
} }
fmt.Println(cmd.NowAllOutput())
fmt.Print("all is ")
fmt.Println(cmd.AllOutPut())
fmt.Println(cmd.ExitCode())
} }

View File

@ -1,4 +1,5 @@
// +build linux darwin //go:build linux
// +build linux
package staros package staros
@ -10,13 +11,19 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"syscall" "syscall"
"time" "time"
"golang.org/x/sys/unix"
) )
//FindProcessByName 通过进程名来查询应用信息 var loadCurrentKeepCaps = currentKeepCaps
// FindProcessByName 通过进程名来查询应用信息
func FindProcessByName(name string) (datas []Process, err error) { func FindProcessByName(name string) (datas []Process, err error) {
return FindProcess(func(in Process) bool { return FindProcess(func(in Process) bool {
if name == in.Name { if name == in.Name {
@ -30,43 +37,11 @@ func FindProcessByName(name string) (datas []Process, err error) {
func FindProcess(compare func(Process) bool) (datas []Process, err error) { func FindProcess(compare func(Process) bool) (datas []Process, err error) {
var name, main string var name, main string
var mainb []byte var mainb []byte
var netErr error netSnapshot := loadNetSnapshot(false)
var netInfo []NetConn
paths, err := ioutil.ReadDir("/proc") paths, err := ioutil.ReadDir("/proc")
if err != nil { if err != nil {
return return
} }
netInfo, netErr = NetConnections(false, "")
appendNetInfo := func(p *Process) {
if netErr != nil {
p.netErr = netErr
return
}
fds, err := ioutil.ReadDir("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd")
if err != nil && Exists("/proc/"+strconv.Itoa(int(p.Pid))+"/fd") {
p.netErr = err
return
}
for _, fd := range fds {
socket, err := os.Readlink("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd/" + fd.Name())
if err != nil {
p.netErr = err
return
}
start := strings.Index(socket, "[")
if start < 0 {
continue
}
sid := socket[start+1 : len(socket)-1]
for _, v := range netInfo {
if v.Inode == sid {
v.Pid = p.Pid
v.Process = p
p.netConn = append(p.netConn, v)
}
}
}
}
for _, v := range paths { for _, v := range paths {
if v.IsDir() && Exists("/proc/"+v.Name()+"/comm") { if v.IsDir() && Exists("/proc/"+v.Name()+"/comm") {
name, err = readAsString("/proc/" + v.Name() + "/comm") name, err = readAsString("/proc/" + v.Name() + "/comm")
@ -83,7 +58,7 @@ func FindProcess(compare func(Process) bool) (datas []Process, err error) {
if err != nil { if err != nil {
tmp.Err = err tmp.Err = err
if compare(tmp) { if compare(tmp) {
appendNetInfo(&tmp) netSnapshot.appendTo(&tmp)
datas = append(datas, tmp) datas = append(datas, tmp)
continue continue
} }
@ -94,28 +69,22 @@ func FindProcess(compare func(Process) bool) (datas []Process, err error) {
tmp.TPid, _ = strconv.ParseInt(data["TracerPid"], 10, 64) tmp.TPid, _ = strconv.ParseInt(data["TracerPid"], 10, 64)
uids := splitBySpace(data["Uid"]) uids := splitBySpace(data["Uid"])
gids := splitBySpace(data["Gid"]) gids := splitBySpace(data["Gid"])
tmp.RUID, _ = strconv.Atoi(uids[0]) tmp.RUID, _ = atoiField(uids, 0)
tmp.EUID, _ = strconv.Atoi(uids[1]) tmp.EUID, _ = atoiField(uids, 1)
tmp.RGID, _ = strconv.Atoi(gids[0]) tmp.RGID, _ = atoiField(gids, 0)
tmp.EGID, _ = strconv.Atoi(gids[1]) tmp.EGID, _ = atoiField(gids, 1)
tmp.VmPeak, _ = strconv.ParseInt(splitBySpace(data["VmPeak"])[0], 10, 64) tmp.VmPeak = parseProcStatusKB(data["VmPeak"])
tmp.VmSize, _ = strconv.ParseInt(splitBySpace(data["VmSize"])[0], 10, 64) tmp.VmSize = parseProcStatusKB(data["VmSize"])
tmp.VmHWM, _ = strconv.ParseInt(splitBySpace(data["VmHWM"])[0], 10, 64) tmp.VmHWM = parseProcStatusKB(data["VmHWM"])
tmp.VmRSS, _ = strconv.ParseInt(splitBySpace(data["VmRSS"])[0], 10, 64) tmp.VmRSS = parseProcStatusKB(data["VmRSS"])
tmp.VmLck, _ = strconv.ParseInt(splitBySpace(data["VmLck"])[0], 10, 64) tmp.VmLck = parseProcStatusKB(data["VmLck"])
tmp.VmData, _ = strconv.ParseInt(splitBySpace(data["VmData"])[0], 10, 64) tmp.VmData = parseProcStatusKB(data["VmData"])
tmp.VmLck *= 1024
tmp.VmData *= 1024
tmp.VmPeak *= 1024
tmp.VmSize *= 1024
tmp.VmHWM *= 1024
tmp.VmRSS *= 1024
} }
mainb, err = ioutil.ReadFile("/proc/" + v.Name() + "/cmdline") mainb, err = ioutil.ReadFile("/proc/" + v.Name() + "/cmdline")
if err != nil { if err != nil {
tmp.Err = err tmp.Err = err
if compare(tmp) { if compare(tmp) {
appendNetInfo(&tmp) netSnapshot.appendTo(&tmp)
datas = append(datas, tmp) datas = append(datas, tmp)
continue continue
} }
@ -129,7 +98,7 @@ func FindProcess(compare func(Process) bool) (datas []Process, err error) {
if err != nil { if err != nil {
tmp.Err = err tmp.Err = err
if compare(tmp) { if compare(tmp) {
appendNetInfo(&tmp) netSnapshot.appendTo(&tmp)
datas = append(datas, tmp) datas = append(datas, tmp)
continue continue
} }
@ -144,17 +113,15 @@ func FindProcess(compare func(Process) bool) (datas []Process, err error) {
if err != nil { if err != nil {
tmp.Err = err tmp.Err = err
if compare(tmp) { if compare(tmp) {
appendNetInfo(&tmp) netSnapshot.appendTo(&tmp)
datas = append(datas, tmp) datas = append(datas, tmp)
continue continue
} }
} else { } else if uptime, ok := procStartTimeFromStat([]byte(main)); ok {
times := splitBySpace(main) tmp.Uptime = uptime
uptime, _ := strconv.ParseInt(strings.TrimSpace(times[21]), 10, 64)
tmp.Uptime = time.Unix(StartTime().Unix()+uptime/100, int64((float64(uptime)/100-float64(uptime/100))*1000000000))
} }
if compare(tmp) { if compare(tmp) {
appendNetInfo(&tmp) netSnapshot.appendTo(&tmp)
datas = append(datas, tmp) datas = append(datas, tmp)
} }
} }
@ -170,38 +137,6 @@ func FindProcessByPid(pid int64) (datas Process, err error) {
err = errors.New("Not Found") err = errors.New("Not Found")
return return
} }
netInfo, netErr := NetConnections(false, "")
appendNetInfo := func(p *Process) {
if netErr != nil {
p.netErr = netErr
return
}
fds, err := ioutil.ReadDir("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd")
if err != nil && Exists("/proc/"+strconv.Itoa(int(p.Pid))+"/fd") {
p.netErr = err
return
}
for _, fd := range fds {
socket, err := os.Readlink("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd/" + fd.Name())
if err != nil {
p.netErr = err
return
}
start := strings.Index(socket, "[")
if start < 0 {
continue
}
sid := socket[start+1 : len(socket)-1]
for _, v := range netInfo {
if v.Inode == sid {
v.Pid = p.Pid
v.Process = p
p.netConn = append(p.netConn, v)
}
}
}
}
name, err = readAsString("/proc/" + fmt.Sprint(pid) + "/comm") name, err = readAsString("/proc/" + fmt.Sprint(pid) + "/comm")
if err != nil { if err != nil {
return return
@ -217,23 +152,17 @@ func FindProcessByPid(pid int64) (datas Process, err error) {
datas.TPid, _ = strconv.ParseInt(data["TracerPid"], 10, 64) datas.TPid, _ = strconv.ParseInt(data["TracerPid"], 10, 64)
uids := splitBySpace(data["Uid"]) uids := splitBySpace(data["Uid"])
gids := splitBySpace(data["Gid"]) gids := splitBySpace(data["Gid"])
datas.RUID, _ = strconv.Atoi(uids[0]) datas.RUID, _ = atoiField(uids, 0)
datas.EUID, _ = strconv.Atoi(uids[1]) datas.EUID, _ = atoiField(uids, 1)
datas.RGID, _ = strconv.Atoi(gids[0]) datas.RGID, _ = atoiField(gids, 0)
datas.EGID, _ = strconv.Atoi(gids[1]) datas.EGID, _ = atoiField(gids, 1)
datas.VmPeak, _ = strconv.ParseInt(splitBySpace(data["VmPeak"])[0], 10, 64) datas.VmPeak = parseProcStatusKB(data["VmPeak"])
datas.VmSize, _ = strconv.ParseInt(splitBySpace(data["VmSize"])[0], 10, 64) datas.VmSize = parseProcStatusKB(data["VmSize"])
datas.VmHWM, _ = strconv.ParseInt(splitBySpace(data["VmHWM"])[0], 10, 64) datas.VmHWM = parseProcStatusKB(data["VmHWM"])
datas.VmRSS, _ = strconv.ParseInt(splitBySpace(data["VmRSS"])[0], 10, 64) datas.VmRSS = parseProcStatusKB(data["VmRSS"])
datas.VmLck, _ = strconv.ParseInt(splitBySpace(data["VmLck"])[0], 10, 64) datas.VmLck = parseProcStatusKB(data["VmLck"])
datas.VmData, _ = strconv.ParseInt(splitBySpace(data["VmData"])[0], 10, 64) datas.VmData = parseProcStatusKB(data["VmData"])
datas.VmLck *= 1024 loadNetSnapshot(false).appendTo(&datas)
datas.VmData *= 1024
datas.VmPeak *= 1024
datas.VmSize *= 1024
datas.VmHWM *= 1024
datas.VmRSS *= 1024
appendNetInfo(&datas)
mainb, err = ioutil.ReadFile("/proc/" + fmt.Sprint(pid) + "/cmdline") mainb, err = ioutil.ReadFile("/proc/" + fmt.Sprint(pid) + "/cmdline")
if err != nil { if err != nil {
datas.Err = err datas.Err = err
@ -264,12 +193,92 @@ func FindProcessByPid(pid int64) (datas Process, err error) {
if err != nil { if err != nil {
return return
} }
times := splitBySpace(main) if uptime, ok := procStartTimeFromStat([]byte(main)); ok {
uptime, _ := strconv.ParseInt(strings.TrimSpace(times[21]), 10, 64) datas.Uptime = uptime
datas.Uptime = time.Unix(StartTime().Unix()+uptime/100, int64((float64(uptime)/100-float64(uptime/100))*1000000000)) }
return return
} }
func procStartTimeFromStat(content []byte) (time.Time, bool) {
fields := splitProcStat(content)
if len(fields) <= 22 {
return time.Time{}, false
}
startTicks, err := strconv.ParseInt(strings.TrimSpace(fields[22]), 10, 64)
if err != nil {
return time.Time{}, false
}
ticks := int64(clockTicks())
seconds := startTicks / ticks
nanos := (startTicks % ticks) * int64(time.Second) / ticks
return time.Unix(StartTime().Unix()+seconds, nanos), true
}
func atoiField(fields []string, index int) (int, error) {
if index < 0 || index >= len(fields) {
return 0, errors.New("field index out of range")
}
return strconv.Atoi(fields[index])
}
func parseProcStatusKB(value string) int64 {
fields := splitBySpace(value)
if len(fields) == 0 || fields[0] == "" {
return 0
}
size, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0
}
return size * 1024
}
type netSnapshot struct {
conns []NetConn
err error
}
func loadNetSnapshot(analysePid bool) netSnapshot {
netInfo, err := NetConnections(analysePid, "")
return netSnapshot{conns: netInfo, err: err}
}
func appendNetInfo(p *Process, analysePid bool) {
loadNetSnapshot(analysePid).appendTo(p)
}
func (snapshot netSnapshot) appendTo(p *Process) {
if snapshot.err != nil {
p.netErr = snapshot.err
return
}
fds, err := ioutil.ReadDir("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd")
if err != nil {
if Exists("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd") {
p.netErr = err
}
return
}
for _, fd := range fds {
socket, err := os.Readlink("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd/" + fd.Name())
if err != nil {
continue
}
start := strings.Index(socket, "[")
if start < 0 {
continue
}
sid := socket[start+1 : len(socket)-1]
for _, v := range snapshot.conns {
if v.Inode == sid {
v.Pid = p.Pid
v.Process = p
p.netConn = append(p.netConn, v)
}
}
}
}
func Daemon(path string, args ...string) (int, error) { func Daemon(path string, args ...string) (int, error) {
cmd := exec.Command(path, args...) cmd := exec.Command(path, args...)
cmd.SysProcAttr = &syscall.SysProcAttr{ cmd.SysProcAttr = &syscall.SysProcAttr{
@ -302,17 +311,58 @@ func DaemonWithUser(uid, gid uint32, groups []uint32, path string, args ...strin
} }
func (starcli *StarCmd) SetRunUser(uid, gid uint32, groups []uint32) { func (starcli *StarCmd) SetRunUser(uid, gid uint32, groups []uint32) {
starcli.CMD.SysProcAttr = &syscall.SysProcAttr{ _ = starcli.SetRunUserE(uid, gid, groups)
Credential: &syscall.Credential{ }
Uid: uid,
Gid: gid, func (starcli *StarCmd) SetRunUserE(uid, gid uint32, groups []uint32) error {
Groups: groups, if starcli == nil || starcli.CMD == nil {
}, return errNilCommand
Setsid: true,
} }
if atomic.LoadInt32(&starcli.started) != 0 {
return errCommandAlreadyStarted
}
if starcli.CMD.SysProcAttr == nil {
starcli.CMD.SysProcAttr = &syscall.SysProcAttr{}
}
if starcli.CMD.SysProcAttr.Credential == nil {
starcli.CMD.SysProcAttr.Credential = &syscall.Credential{}
}
starcli.CMD.SysProcAttr.Credential.Uid = uid
starcli.CMD.SysProcAttr.Credential.Gid = gid
starcli.CMD.SysProcAttr.Credential.Groups = append([]uint32(nil), groups...)
starcli.CMD.SysProcAttr.Setsid = true
return nil
} }
func (starcli *StarCmd) Release() error { func (starcli *StarCmd) Release() error {
return starcli.ReleaseE()
}
func (starcli *StarCmd) Detach() error {
return starcli.DetachE()
}
func (starcli *StarCmd) ReleaseE() error {
if starcli == nil || starcli.CMD == nil {
return errNilCommand
}
if !atomic.CompareAndSwapInt32(&starcli.released, 0, 1) {
return errCommandAlreadyReleased
}
if atomic.LoadInt32(&starcli.started) != 0 {
if starcli.CMD.Process == nil {
starcli.lock.Lock()
err := starcli.runerr
starcli.lock.Unlock()
if err != nil {
atomic.StoreInt32(&starcli.released, 0)
return err
}
atomic.StoreInt32(&starcli.released, 0)
return errCommandAlreadyStarted
}
return nil
}
if starcli.CMD.SysProcAttr == nil { if starcli.CMD.SysProcAttr == nil {
starcli.CMD.SysProcAttr = &syscall.SysProcAttr{ starcli.CMD.SysProcAttr = &syscall.SysProcAttr{
Setsid: true, Setsid: true,
@ -322,27 +372,121 @@ func (starcli *StarCmd) Release() error {
starcli.CMD.SysProcAttr.Setsid = true starcli.CMD.SysProcAttr.Setsid = true
} }
} }
if !starcli.IsRunning() { if err := starcli.Start(); err != nil {
if err := starcli.CMD.Start(); err != nil { atomic.StoreInt32(&starcli.released, 0)
return err return err
} }
return nil
}
func (starcli *StarCmd) DetachE() error {
if starcli == nil || starcli.CMD == nil {
return errNilCommand
} }
time.Sleep(time.Millisecond * 10) if !atomic.CompareAndSwapInt32(&starcli.detached, 0, 1) {
return starcli.CMD.Process.Release() return errCommandAlreadyDetached
}
if atomic.LoadInt32(&starcli.started) != 0 {
atomic.StoreInt32(&starcli.detached, 0)
return errCommandAlreadyStarted
}
cmd := exec.Command(starcli.CMD.Path, starcli.CMD.Args[1:]...)
cmd.Dir = starcli.CMD.Dir
cmd.Env = append([]string(nil), starcli.CMD.Env...)
if starcli.CMD.SysProcAttr != nil {
attr := *starcli.CMD.SysProcAttr
cmd.SysProcAttr = &attr
} else {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
}
}
if !cmd.SysProcAttr.Setsid {
cmd.SysProcAttr.Setsid = true
}
devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0)
if err != nil {
atomic.StoreInt32(&starcli.detached, 0)
return err
}
defer devNull.Close()
cmd.Stdin = devNull
cmd.Stdout = devNull
cmd.Stderr = devNull
if err := cmd.Start(); err != nil {
atomic.StoreInt32(&starcli.detached, 0)
return err
}
starcli.CMD.Process = cmd.Process
atomic.StoreInt32(&starcli.started, 1)
starcli.setRunning(false)
starcli.finish()
if err := cmd.Process.Release(); err != nil {
atomic.StoreInt32(&starcli.detached, 0)
return err
}
return nil
} }
func (starcli *StarCmd) SetKeepCaps() error { func (starcli *StarCmd) SetKeepCaps() error {
_, _, err := syscall.RawSyscall(157 /*SYS PRCTL */, 0x8 /*PR SET KEEPCAPS*/, 1, 0) if err := starcli.ensureConfigurable(); err != nil {
if 0 != err {
return err return err
} }
caps, err := loadCurrentKeepCaps()
if err != nil {
return err
}
starcli.lock.Lock()
defer starcli.lock.Unlock()
if starcli.CMD.SysProcAttr == nil {
starcli.CMD.SysProcAttr = &syscall.SysProcAttr{}
}
starcli.CMD.SysProcAttr.AmbientCaps = mergeAmbientCaps(starcli.CMD.SysProcAttr.AmbientCaps, caps)
return nil return nil
} }
func SetKeepCaps() error { func SetKeepCaps() error {
_, _, err := syscall.RawSyscall(157 /*SYS PRCTL */, 0x8 /*PR SET KEEPCAPS*/, 1, 0) return unix.Prctl(unix.PR_SET_KEEPCAPS, 1, 0, 0, 0)
if 0 != err { }
return err
} func currentKeepCaps() ([]uintptr, error) {
return nil hdr := unix.CapUserHeader{Version: unix.LINUX_CAPABILITY_VERSION_3}
data := [2]unix.CapUserData{}
if err := unix.Capget(&hdr, &data[0]); err != nil {
return nil, err
}
return capsFromCapData(data), nil
}
func capsFromCapData(data [2]unix.CapUserData) []uintptr {
var caps []uintptr
for index, item := range data {
mask := item.Permitted
for bit := uint(0); bit < 32; bit++ {
if mask&(1<<bit) == 0 {
continue
}
caps = append(caps, uintptr(index*32)+uintptr(bit))
}
}
return caps
}
func mergeAmbientCaps(existing, extra []uintptr) []uintptr {
if len(existing) == 0 && len(extra) == 0 {
return nil
}
merged := append(append([]uintptr(nil), existing...), extra...)
sort.Slice(merged, func(i, j int) bool {
return merged[i] < merged[j]
})
out := merged[:0]
var last uintptr
for idx, cap := range merged {
if idx == 0 || cap != last {
out = append(out, cap)
last = cap
}
}
return out
} }

View File

@ -1,3 +1,4 @@
//go:build windows
// +build windows // +build windows
package staros package staros
@ -5,8 +6,11 @@ package staros
import ( import (
"errors" "errors"
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"strconv" "strconv"
"sync/atomic"
"time"
"b612.me/wincmd" "b612.me/wincmd"
) )
@ -30,6 +34,24 @@ func FindProcessByName(pname string) (data []Process, err error) {
return return
} }
func FindProcess(compare func(Process) bool) (data []Process, err error) {
var lists []map[string]string
lists, err = wincmd.GetRunningProcess()
if err != nil {
return
}
for _, v := range lists {
var tmp Process
tmp.Name = v["name"]
tmp.Pid, _ = strconv.ParseInt(v["pid"], 10, 64)
tmp.PPid, _ = strconv.ParseInt(v["ppid"], 10, 64)
if compare(tmp) {
data = append(data, tmp)
}
}
return
}
// FindProcessByPid 通过pid来查询应用信息 // FindProcessByPid 通过pid来查询应用信息
func FindProcessByPid(pid int64) (data Process, err error) { func FindProcessByPid(pid int64) (data Process, err error) {
var lists []map[string]string var lists []map[string]string
@ -59,14 +81,117 @@ func Daemon(path string, args ...string) (int, error) {
return pid, nil return pid, nil
} }
func (starcli *StarCmd) SetRunUser(uid, gid uint32, groups []uint32) { func DaemonWithUser(uid, gid uint32, groups []uint32, path string, args ...string) (int, error) {
return -1, ERR_UNSUPPORTED
}
func (starcli *StarCmd) SetRunUser(uid, gid uint32, groups []uint32) {
_ = starcli.SetRunUserE(uid, gid, groups)
}
func (starcli *StarCmd) SetRunUserE(uid, gid uint32, groups []uint32) error {
if starcli == nil || starcli.CMD == nil {
return errNilCommand
}
if atomic.LoadInt32(&starcli.started) != 0 {
return errCommandAlreadyStarted
}
return ERR_UNSUPPORTED
} }
func (starcli *StarCmd) Release() error { func (starcli *StarCmd) Release() error {
if err := starcli.CMD.Start(); err != nil { return starcli.ReleaseE()
}
func (starcli *StarCmd) Detach() error {
return starcli.DetachE()
}
func (starcli *StarCmd) ReleaseE() error {
if starcli == nil || starcli.CMD == nil {
return errNilCommand
}
if !atomic.CompareAndSwapInt32(&starcli.released, 0, 1) {
return errCommandAlreadyReleased
}
if atomic.LoadInt32(&starcli.started) != 0 {
if starcli.CMD.Process == nil {
starcli.lock.Lock()
err := starcli.runerr
starcli.lock.Unlock()
if err != nil {
atomic.StoreInt32(&starcli.released, 0)
return err
}
atomic.StoreInt32(&starcli.released, 0)
return errCommandAlreadyStarted
}
return nil
}
if err := starcli.Start(); err != nil {
atomic.StoreInt32(&starcli.released, 0)
return err return err
} }
starcli.CMD.Process.Release()
return nil return nil
} }
func (starcli *StarCmd) DetachE() error {
if starcli == nil || starcli.CMD == nil {
return errNilCommand
}
if !atomic.CompareAndSwapInt32(&starcli.detached, 0, 1) {
return errCommandAlreadyDetached
}
if atomic.LoadInt32(&starcli.started) != 0 {
atomic.StoreInt32(&starcli.detached, 0)
return errCommandAlreadyStarted
}
cmd := exec.Command(starcli.CMD.Path, starcli.CMD.Args[1:]...)
cmd.Dir = starcli.CMD.Dir
cmd.Env = append([]string(nil), starcli.CMD.Env...)
if starcli.CMD.SysProcAttr != nil {
attr := *starcli.CMD.SysProcAttr
cmd.SysProcAttr = &attr
}
devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0)
if err != nil {
atomic.StoreInt32(&starcli.detached, 0)
return err
}
defer devNull.Close()
cmd.Stdin = devNull
cmd.Stdout = devNull
cmd.Stderr = devNull
if err := cmd.Start(); err != nil {
atomic.StoreInt32(&starcli.detached, 0)
return err
}
starcli.CMD.Process = cmd.Process
atomic.StoreInt32(&starcli.started, 1)
starcli.setRunning(false)
starcli.finish()
if err := cmd.Process.Release(); err != nil {
atomic.StoreInt32(&starcli.detached, 0)
return err
}
return nil
}
func (starcli *StarCmd) SetKeepCaps() error {
if starcli == nil || starcli.CMD == nil {
return errNilCommand
}
return ERR_UNSUPPORTED
}
func SetKeepCaps() error {
return ERR_UNSUPPORTED
}
func CpuUsageByPid(pid int, sleep time.Duration) float64 {
return 0
}
func Whoami() (uid, gid int, uname, gname, home string, err error) {
return 0, 0, "", "", "", ERR_UNSUPPORTED
}

1716
sysconf/config.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -2,13 +2,15 @@ package sysconf
import ( import (
"bytes" "bytes"
"encoding/csv"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"strings" "strings"
) )
var ErrNilCSVValue = errors.New("nil csv value")
type CSV struct { type CSV struct {
header []string header []string
text [][]string text [][]string
@ -24,238 +26,152 @@ type CSVValue struct {
value string value string
} }
func ParseCSV(data []byte, hasHeader bool) (csv CSV, err error) { func ParseCSV(data []byte, hasHeader bool) (csvData CSV, err error) {
strData := strings.Split(string(bytes.TrimSpace(data)), "\n") if len(data) == 0 {
if len(strData) < 1 { return CSV{}, fmt.Errorf("cannot parse data,invalid data format")
err = fmt.Errorf("cannot parse data,invalid data format")
} }
var header []string reader := csv.NewReader(bytes.NewReader(data))
var text [][]string records, err := reader.ReadAll()
if err != nil {
return CSV{}, err
}
if len(records) == 0 {
return CSV{}, fmt.Errorf("cannot parse data,invalid data format")
}
start := 0
if hasHeader { if hasHeader {
header = csvAnalyse(strData[0]) csvData.header = append([]string(nil), records[0]...)
strData = strData[1:] start = 1
} else { } else {
num := len(csvAnalyse(strData[0])) for i := range records[0] {
for i := 0; i < num; i++ { csvData.header = append(csvData.header, fmt.Sprint(i))
header = append(header, strconv.Itoa(i))
} }
} }
for k, v := range strData { for _, record := range records[start:] {
tmpData := csvAnalyse(v) if len(record) != len(csvData.header) {
if len(tmpData) != len(header) { return CSV{}, fmt.Errorf("cannot parse data line,got %d values but need %d", len(record), len(csvData.header))
err = fmt.Errorf("cannot parse data line %d,got %d values but need %d", k, len(tmpData), len(header))
return
} }
text = append(text, tmpData) csvData.text = append(csvData.text, append([]string(nil), record...))
} }
csv.header = header return csvData, nil
csv.text = text
return
} }
func (csv *CSV) Header() []string { func (csvData *CSV) Header() []string { return csvData.header }
return csv.header func (csvData *CSV) Data() [][]string { return csvData.text }
}
func (csv *CSV) Data() [][]string { func (csvData *CSV) Row(row int) *CSVRow {
return csv.text if csvData == nil || row < 0 || row >= len(csvData.text) {
}
func (csv *CSV) Row(row int) *CSVRow {
if row >= len(csv.Data()) {
return nil return nil
} }
return &CSVRow{ return &CSVRow{header: csvData.header, data: csvData.text[row]}
header: csv.Header(),
data: csv.Data()[row],
}
} }
func (csv *CSVRow) Get(key string) *CSVValue { func (row *CSVRow) Get(key string) *CSVValue {
for k, v := range csv.header { if row == nil {
if v == key { return nil
return &CSVValue{
key: key,
value: csv.data[k],
} }
for idx, header := range row.header {
if header == key {
return &CSVValue{key: key, value: row.data[idx]}
} }
} }
return nil return nil
} }
func (csv *CSVRow) Col(key int) *CSVValue { func (row *CSVRow) Col(key int) *CSVValue {
if key >= len(csv.header) { if row == nil || key < 0 || key >= len(row.header) {
return nil return nil
} }
return &CSVValue{ return &CSVValue{key: row.header[key], value: row.data[key]}
key: csv.header[key],
value: csv.data[key],
}
} }
func (csv *CSVRow) Header() []string { func (row *CSVRow) Header() []string { return row.header }
return csv.header
}
func (csv *CSV) MapData() []map[string]string { func (csvData *CSV) MapData() []map[string]string {
var result []map[string]string var result []map[string]string
for _, v := range csv.text { for _, record := range csvData.text {
tmp := make(map[string]string) item := make(map[string]string, len(csvData.header))
for k, v2 := range csv.header { for idx, header := range csvData.header {
tmp[v2] = v[k] item[header] = record[idx]
} }
result = append(result, tmp) result = append(result, item)
} }
return result return result
} }
func CsvAnalyse(data string) []string { func CsvAnalyse(data string) []string { return csvAnalyse(data) }
return csvAnalyse(data)
}
func csvAnalyse(data string) []string { func csvAnalyse(data string) []string {
var segStart bool = false reader := csv.NewReader(strings.NewReader(data))
var segReady bool = false record, err := reader.Read()
var segSign string = "" if err != nil {
var dotReady bool = false return []string{}
data = strings.TrimSpace(data)
var result []string
var seg string
for k, v := range []rune(data) {
if k == 0 && v != []rune(`"`)[0] {
dotReady = true
} }
if v != []rune(`,`)[0] && dotReady { return record
segSign = `,`
segStart = true
dotReady = false
if v == []rune(`"`)[0] {
segSign = `"`
continue
}
}
if dotReady && v == []rune(`,`)[0] {
//dotReady = false
result = append(result, "")
continue
}
if v == []rune(`"`)[0] && segStart {
if !segReady {
segReady = true
continue
}
seg += `"`
segReady = false
continue
}
if segReady && segSign == `"` && segStart {
segReady = false
segStart = false
result = append(result, seg)
segSign = ``
seg = ""
}
if v == []rune(`"`)[0] && !segStart {
segStart = true
segReady = false
segSign = `"`
continue
}
if v == []rune(`,`)[0] && !segStart {
dotReady = true
}
if v == []rune(`,`)[0] && segStart && segSign == "," {
segStart = false
result = append(result, seg)
dotReady = true
segSign = ``
seg = ""
}
if segStart {
seg = string(append([]rune(seg), v))
}
}
if len(data) != 0 && len(result) == 0 && seg == "" {
result = append(result, data)
} else {
result = append(result, seg)
}
return result
} }
func MarshalCSV(header []string, ins interface{}) ([]byte, error) { func MarshalCSV(header []string, ins interface{}) ([]byte, error) {
var result [][]string
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins) v := reflect.ValueOf(ins)
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { if v.Kind() == reflect.Ptr {
return nil, errors.New("not a Slice or Array") if v.IsNil() {
return nil, ErrNilCSVValue
} }
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem() v = v.Elem()
} }
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return nil, fmt.Errorf("not a Slice or Array")
}
rows := make([][]string, 0, v.Len())
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
subT := reflect.TypeOf(v.Index(i).Interface()) item := v.Index(i)
subV := reflect.ValueOf(v.Index(i).Interface()) if item.Kind() == reflect.Ptr {
if subV.Kind() == reflect.Slice || subV.Kind() == reflect.Array { if item.IsNil() {
if subT.Kind() == reflect.Ptr { continue
subV = subV.Elem()
} }
var tmp []string item = item.Elem()
for j := 0; j < subV.Len(); j++ {
tmp = append(tmp, fmt.Sprint(reflect.ValueOf(subV.Index(j))))
} }
result = append(result, tmp) switch item.Kind() {
case reflect.Slice, reflect.Array:
row := make([]string, 0, item.Len())
for j := 0; j < item.Len(); j++ {
row = append(row, fmt.Sprint(item.Index(j).Interface()))
} }
if subV.Kind() == reflect.Struct { rows = append(rows, row)
var tmp []string case reflect.Struct:
if subT.Kind() == reflect.Ptr { row := make([]string, 0, item.NumField())
subV = subV.Elem() for j := 0; j < item.NumField(); j++ {
field := item.Field(j)
if !field.CanInterface() {
continue
} }
for i := 0; i < subV.NumField(); i++ { row = append(row, fmt.Sprint(field.Interface()))
tmp = append(tmp, fmt.Sprint(subV.Field(i)))
} }
result = append(result, tmp) rows = append(rows, row)
} }
} }
width := 0
return buildCSV(header,result) if len(header) > 0 {
} width = len(header)
} else if len(rows) > 0 {
func buildCSV(header []string, data [][]string) ([]byte, error) { width = len(rows[0])
var result []string }
var length int for idx, row := range rows {
build := func(slc []string) string { if len(row) != width {
for k, v := range slc { return nil, fmt.Errorf("line %d got length %d ,but need %d", idx, len(row), width)
if strings.Index(v, `"`) >= 0 { }
v = strings.ReplaceAll(v, `"`, `""`) }
} var buf bytes.Buffer
if strings.Index(v,"\n")>=0 { writer := csv.NewWriter(&buf)
v=strings.ReplaceAll(v,"\n",`\n`) if len(header) > 0 {
} if err := writer.Write(header); err != nil {
if strings.Index(v,"\r")>=0 { return nil, err
v=strings.ReplaceAll(v,"\r",`\r`) }
} }
v = `"` + v + `"` for _, row := range rows {
slc[k] = v if err := writer.Write(row); err != nil {
} return nil, err
return strings.Join(slc, ",") }
} }
if len(header) != 0 { writer.Flush()
result = append(result, build(header)) return buf.Bytes(), writer.Error()
length = len(header)
} else {
length = len(data[0])
}
for k, v := range data {
if len(v) != length {
return nil, fmt.Errorf("line %d got length %d ,but need %d", k, len(v), length)
}
result = append(result, build(v))
}
return []byte(strings.Join(result, "\n")), nil
} }

View File

@ -1,38 +0,0 @@
package sysconf
import (
"fmt"
"testing"
)
func Test_csv(t *testing.T) {
//var test Sqlplus
var text=`
姓名,班级,性别,年龄
张三,"我,不""知道",boy,23
"里斯","哈哈",girl,23
`
fmt.Println(csvAnalyse(`请求权,lkjdshck,dsvdsv,"sdvkjsdv,",=dsvdsv,"=,dsvsdv"`))
a,b:=ParseCSV([]byte(text),true)
fmt.Println(b)
fmt.Println(a.Row(0).Col(3).MustInt())
}
type csvtest struct {
A string
B int
}
func Test_Masharl(t *testing.T) {
//var test Sqlplus
/*
var a []csvtest = []csvtest{
{"lala",1},
{"haha",34},
}
*/
var a [][]string
a=append(a,[]string{"a","b","c"})
a=append(a,[]string{"1",`s"s"d`,"3"})
b,_:=MarshalCSV([]string{},a)
fmt.Println(string(b))
}

View File

@ -1,120 +0,0 @@
package sysconf
import "strconv"
func (csv *CSVValue)Key()string {
return csv.key
}
func (csv *CSVValue)Int()(int,error) {
tmp,err:=strconv.Atoi(csv.value)
return tmp,err
}
func (csv *CSVValue)MustInt()int {
tmp,err:=csv.Int()
if err!=nil {
panic(err)
}
return tmp
}
func (csv *CSVValue)Int64()(int64,error) {
tmp,err:=strconv.ParseInt(csv.value,10,64)
return tmp,err
}
func (csv *CSVValue)MustInt64()int64 {
tmp,err:=csv.Int64()
if err!=nil {
panic(err)
}
return tmp
}
func (csv *CSVValue)Int32()(int32,error) {
tmp,err:=strconv.ParseInt(csv.value,10,32)
return int32(tmp),err
}
func (csv *CSVValue)MustInt32()int32 {
tmp,err:=csv.Int32()
if err!=nil {
panic(err)
}
return tmp
}
func (csv *CSVValue)Uint64()(uint64,error) {
tmp,err:=strconv.ParseUint(csv.value,10,64)
return tmp,err
}
func (csv *CSVValue)MustUint64()uint64 {
tmp,err:=csv.Uint64()
if err!=nil {
panic(err)
}
return tmp
}
func (csv *CSVValue)Uint32()(uint32,error) {
tmp,err:=strconv.ParseUint(csv.value,10,32)
return uint32(tmp),err
}
func (csv *CSVValue)MustUint32()uint32 {
tmp,err:=csv.Uint32()
if err!=nil {
panic(err)
}
return tmp
}
func (csv *CSVValue)String()string {
return csv.value
}
func (csv *CSVValue)Byte()[]byte {
return []byte(csv.value)
}
func (csv *CSVValue)Bool()(bool,error) {
tmp,err:=strconv.ParseBool(csv.value)
return tmp,err
}
func (csv *CSVValue)MustBool()bool {
tmp,err:=csv.Bool()
if err!=nil {
panic(err)
}
return tmp
}
func (csv *CSVValue)Float64()(float64,error) {
tmp,err:=strconv.ParseFloat(csv.value,64)
return tmp,err
}
func (csv *CSVValue)MustFloat64()float64 {
tmp,err:=csv.Float64()
if err!=nil {
panic(err)
}
return tmp
}
func (csv *CSVValue)Float32()(float32,error) {
tmp,err:=strconv.ParseFloat(csv.value,32)
return float32(tmp),err
}
func (csv *CSVValue)MustFloat32()float32 {
tmp,err:=csv.Float32()
if err!=nil {
panic(err)
}
return tmp
}

1192
sysconf/document.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,27 @@
package sysconf
import (
"fmt"
"strings"
)
func ExampleNewIni_migration() {
ini := NewIni()
_ = ini.Parse([]byte("[app]\nport=8080\nfeature=alpha\nfeature=beta\n"))
app := ini.Section("app")
_ = app.SetInt("port", 9090, "")
_ = app.SetAll("feature", []string{"stable", "audit"}, "")
fmt.Println(ini.Get("app", "port"))
fmt.Println(ini.GetAll("app", "feature"))
fmt.Println(strings.TrimSpace(string(ini.Build())))
// Output:
// 9090
// [stable audit]
// [app]
// port=9090
// feature=stable
// feature=audit
}

445
sysconf/ini.go Normal file
View File

@ -0,0 +1,445 @@
package sysconf
import (
"errors"
"fmt"
"os"
"reflect"
"sort"
)
type Ini struct {
*Document
}
type IniProfile func(*Ini)
func NewIni() *Ini {
return &Ini{Document: NewDocument()}
}
func NewIniWithProfiles(profiles ...IniProfile) *Ini {
ini := NewIni()
for _, profile := range profiles {
if profile != nil {
profile(ini)
}
}
return ini
}
func DefaultINIProfile() IniProfile {
return func(ini *Ini) {
if ini == nil {
return
}
ini.Document = NewDocument()
}
}
func StrictINIProfile() IniProfile {
return func(ini *Ini) {
if ini == nil {
return
}
if ini.Document == nil {
ini.Document = NewDocument()
}
ini.Strict = true
ini.AllowNoValue = false
}
}
func LinuxConfProfile(equal string) IniProfile {
return func(ini *Ini) {
if ini == nil {
return
}
if ini.Document == nil {
ini.Document = NewDocument()
}
ini.SectionOpen = ""
ini.SectionClose = ""
ini.CommentHeads = []string{"#"}
if equal != "" {
ini.Assign = equal
}
ini.AssignDelimiters = []string{ini.Assign}
}
}
func (i *Ini) ApplyProfile(profile IniProfile) *Ini {
if profile != nil {
profile(i)
}
return i
}
func NewSysConf(equal string) *Ini {
ini := NewIni()
if equal != "" {
ini.Assign = equal
}
ini.AssignDelimiters = []string{ini.Assign}
return ini
}
func NewLinuxConf(equal string) *Ini {
return NewIniWithProfiles(LinuxConfProfile(equal))
}
func (i *Ini) Parse(data []byte) error {
if i == nil || i.Document == nil {
return ErrDocumentClosed
}
return i.Document.Parse(data)
}
func (i *Ini) ParseFromFile(path string) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
return i.Parse(data)
}
func (i *Ini) Build() []byte {
if i == nil || i.Document == nil {
return nil
}
return i.Document.Bytes()
}
func (i *Ini) Save(path string) error {
if i == nil || i.Document == nil {
return ErrDocumentClosed
}
return i.Document.Save(path)
}
func (i *Ini) SaveAtomic(path string) error {
if i == nil || i.Document == nil {
return ErrDocumentClosed
}
return i.Document.SaveAtomic(path)
}
func (i *Ini) Section(name string) *Section {
if i == nil || i.Document == nil {
return nil
}
return i.Document.Section(name)
}
func (i *Ini) Sections(name string) []*Section {
if i == nil || i.Document == nil {
return nil
}
return i.Document.SectionsByName(name)
}
func (i *Ini) AddSection(name string) *Section {
if i == nil || i.Document == nil {
return nil
}
i.Document.mu.Lock()
defer i.Document.mu.Unlock()
return i.Document.appendSection(name, "", "", "\n")
}
func (i *Ini) DeleteSection(name string) bool {
if i == nil || i.Document == nil {
return false
}
i.Document.mu.Lock()
defer i.Document.mu.Unlock()
i.Document.rebuildSectionIndexLocked()
normalized := normalize(name, i.CaseSensitive)
sections := i.Document.sectionIndex[normalized]
if len(sections) == 0 {
return false
}
delete(i.Document.sectionIndex, normalized)
filtered := i.Document.sections[:0]
for _, section := range i.Document.sections {
if normalize(section.Name, i.CaseSensitive) == normalized {
continue
}
filtered = append(filtered, section)
}
i.Document.sections = filtered
return true
}
func (i *Ini) Get(section, key string) string {
for _, s := range i.Sections(section) {
if s != nil && s.Exist(key) {
return s.Get(key)
}
}
return ""
}
func (i *Ini) GetAll(section, key string) []string {
sections := i.Sections(section)
if len(sections) == 0 {
return nil
}
values := make([]string, 0)
for _, s := range sections {
if s == nil {
continue
}
values = append(values, s.GetAll(key)...)
}
if len(values) == 0 {
return nil
}
return values
}
func (i *Ini) Has(section, key string) bool {
for _, s := range i.Sections(section) {
if s != nil && s.Exist(key) {
return true
}
}
return false
}
func (i *Ini) Set(section, key, value string) {
if i == nil || i.Document == nil {
return
}
i.Document.mu.Lock()
s := i.Document.ensureSection(section)
i.Document.mu.Unlock()
if s != nil {
_ = s.Set(key, value, "")
}
}
func (i *Ini) AddValue(section, key, value string) {
if i == nil || i.Document == nil {
return
}
i.Document.mu.Lock()
s := i.Document.ensureSection(section)
i.Document.mu.Unlock()
if s != nil {
_ = s.AddValue(key, value, "")
}
}
func (i *Ini) Delete(section, key string) bool {
if s := i.Section(section); s != nil {
return s.Delete(key) == nil
}
return false
}
func (i *Ini) SectionsMap() map[string][]*Section {
if i == nil || i.Document == nil {
return nil
}
i.Document.mu.Lock()
defer i.Document.mu.Unlock()
i.Document.rebuildSectionIndexLocked()
out := make(map[string][]*Section, len(i.Document.sectionIndex))
for name, sections := range i.Document.sectionIndex {
out[name] = append([]*Section(nil), sections...)
}
return out
}
func (i *Ini) Unmarshal(dst interface{}) error {
return bindINI(i, dst)
}
func (i *Ini) Marshal(src interface{}) ([]byte, error) {
tmp := newIniLike(i)
if err := marshalINI(tmp, src); err != nil {
return nil, err
}
return tmp.Build(), nil
}
func bindINI(i *Ini, dst interface{}) error {
if dst == nil {
return errors.New("destination is nil")
}
v := reflect.ValueOf(dst)
if v.Kind() != reflect.Ptr || v.IsNil() {
return errors.New("destination must be a non-nil pointer")
}
v = v.Elem()
if v.Kind() != reflect.Struct {
return errors.New("destination must point to a struct")
}
return bindStruct(i, v, "")
}
func bindStruct(i *Ini, v reflect.Value, inheritedSection string) error {
t := v.Type()
for idx := 0; idx < t.NumField(); idx++ {
field := t.Field(idx)
value := v.Field(idx)
if !value.CanSet() {
continue
}
section := field.Tag.Get("seg")
key := field.Tag.Get("key")
if section == "" {
section = inheritedSection
}
if key == "-" {
continue
}
if isNestedConfigStruct(value, key) {
nested := value
for nested.Kind() == reflect.Ptr {
if nested.IsNil() {
nested.Set(reflect.New(nested.Type().Elem()))
}
nested = nested.Elem()
}
if err := bindStruct(i, nested, section); err != nil {
return err
}
continue
}
if key == "" {
continue
}
items := configValuesFromSections(i.Sections(section), key)
if len(items) == 0 {
continue
}
if err := setINIField(value, items); err != nil {
return err
}
}
return nil
}
func setINIField(value reflect.Value, items []configValue) error {
return setConfigValueItems(value, items)
}
func marshalINI(dst *Ini, src interface{}) error {
v := reflect.ValueOf(src)
if !v.IsValid() {
return errors.New("nil source")
}
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return errors.New("nil source")
}
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return errors.New("source must be struct")
}
return marshalStruct(dst, v, "")
}
func marshalStruct(dst *Ini, v reflect.Value, inheritedSection string) error {
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fv := v.Field(i)
if !fv.CanInterface() {
continue
}
section := field.Tag.Get("seg")
if section == "" {
section = inheritedSection
}
key := field.Tag.Get("key")
comment := field.Tag.Get("comment")
if key == "-" {
continue
}
if nested, ok := nestedConfigValueForWrite(fv, key); ok {
if nested.IsValid() {
if err := marshalStruct(dst, nested, section); err != nil {
return err
}
}
continue
}
if key == "" {
continue
}
if err := setINIValue(dst, section, key, fv); err != nil {
return err
}
if comment != "" {
sec := dst.Section(section)
if sec == nil {
sec = dst.AddSection(section)
}
if sec != nil {
_ = sec.SetComment(key, comment)
}
}
}
return nil
}
func marshalSection(dst *Ini, section string, value reflect.Value) error {
return marshalStruct(dst, value, section)
}
func setINIValue(dst *Ini, section, key string, value reflect.Value) error {
for value.Kind() == reflect.Ptr {
if value.IsNil() {
return nil
}
value = value.Elem()
}
switch value.Kind() {
case reflect.Slice, reflect.Array:
if value.Type().Elem().Kind() != reflect.String {
dst.Set(section, key, fmt.Sprint(value.Interface()))
return nil
}
sec := dst.Section(section)
if sec == nil {
sec = dst.AddSection(section)
}
if sec == nil {
return ErrDocumentClosed
}
values := make([]string, 0, value.Len())
for idx := 0; idx < value.Len(); idx++ {
values = append(values, value.Index(idx).String())
}
return sec.SetAll(key, values, "")
case reflect.Map:
if value.Type().Key().Kind() != reflect.String || value.Type().Elem().Kind() != reflect.String {
dst.Set(section, key, fmt.Sprint(value.Interface()))
return nil
}
keys := make([]string, 0, value.Len())
for _, mapKey := range value.MapKeys() {
keys = append(keys, mapKey.String())
}
sort.Strings(keys)
values := make([]string, 0, len(keys))
for _, mapKey := range keys {
values = append(values, mapKey+"="+value.MapIndex(reflect.ValueOf(mapKey)).String())
}
sec := dst.Section(section)
if sec == nil {
sec = dst.AddSection(section)
}
if sec == nil {
return ErrDocumentClosed
}
return sec.SetAll(key, values, "")
default:
dst.Set(section, key, fmt.Sprint(value.Interface()))
return nil
}
}

View File

@ -1,855 +0,0 @@
package sysconf
import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"reflect"
"strconv"
"strings"
"sync"
)
type SysConf struct {
Data []*SysSegment
segmap map[string]int64
segId int64
HaveSegMent bool //是否有节这个概念
SegStart string
SegEnd string
CommentFlag []string //评论标识符,如#
EqualFlag string //赋值标识符,如=
ValueFlag string //值标识符,如"
EscapeFlag string //转义字符
CommentCR bool //评论是否能与value同一行true不行false可以
SpaceStr string //美化符号
lock sync.RWMutex
}
type SysSegment struct {
Name string
//nodeMap
Comment string
NodeData []*SysNode
nodeId int64
nodeMap map[string]int64
lock sync.RWMutex
}
type SysNode struct {
Key string
Value []string
Comment string
NoValue bool
lock sync.RWMutex
}
func NewIni() *SysConf {
ini := NewSysConf("=")
ini.CommentCR = true
ini.CommentFlag = []string{"#", ";"}
ini.HaveSegMent = true
ini.SegStart = "["
ini.SegEnd = "]"
ini.SpaceStr = " "
ini.EscapeFlag = "\\"
return ini
}
func NewSysConf(EqualFlag string) *SysConf {
syscnf := new(SysConf)
syscnf.EqualFlag = EqualFlag
return syscnf
}
// NewLinuxConf sysctl.conf like file
func NewLinuxConf(EqualFlag string) *SysConf {
syscnf := new(SysConf)
syscnf.EqualFlag = EqualFlag
syscnf.HaveSegMent = false
syscnf.CommentCR = true
syscnf.CommentFlag = []string{"#"}
return syscnf
}
func (syscfg *SysConf) ParseFromFile(filepath string) error {
data, err := ioutil.ReadFile(filepath)
if err != nil {
return err
}
return syscfg.Parse(data)
}
// Parse 生成INI文件结构
func (syscfg *SysConf) Parse(data []byte) error {
syscfg.lock.Lock()
defer syscfg.lock.Unlock()
if syscfg.HaveSegMent && (syscfg.SegStart == "" || syscfg.SegEnd == "") {
return errors.New("SegMent Start or End Flag Not Allowed!")
}
if !syscfg.CommentCR {
return syscfg.parseNOCRComment(data)
}
return syscfg.parseCRComment(data)
}
func (syscfg *SysConf) parseNOCRComment(data []byte) error { //允许comment在同一行
data = bytes.TrimSpace(data)
dataLists := bytes.Split(data, []byte("\n"))
seg := new(SysSegment)
seg.nodeMap = make(map[string]int64)
syscfg.segmap = make(map[string]int64)
if syscfg.HaveSegMent {
seg.Name = "unnamed"
}
syscfg.segId = 0
var node *SysNode
for _, v1 := range dataLists {
var (
isSegStart bool = false
isEscape bool = false
isEqual bool = false
isComment bool = false
tsuMo string = ""
)
cowStr := strings.TrimSpace(string(v1))
for i := 0; i < len(cowStr); i++ {
runeStr := cowStr[i : i+1] //当前字符rune扫描
if runeStr == syscfg.EscapeFlag && (!isEscape) {
isEscape = true
continue
}
if runeStr == syscfg.SegStart && (!isEscape) {
isSegStart = true
continue
}
if runeStr == syscfg.SegEnd && (!isEscape) {
isSegStart = false
//New segment start from here
if seg.Name == "unnamed" && len(seg.NodeData) == 0 {
seg.Name = tsuMo
tsuMo = ""
continue
}
syscfg.segmap[seg.Name] = syscfg.segId
syscfg.segId++
syscfg.Data = append(syscfg.Data, seg)
seg = new(SysSegment)
seg.nodeMap = make(map[string]int64)
seg.Name = tsuMo
tsuMo = ""
continue
}
if isSegStart {
tsuMo += runeStr
if isEscape {
isEscape = false
}
continue
}
if syscfg.EqualFlag == runeStr && (!isEscape) && (!isEqual) {
key := strings.TrimSpace(tsuMo)
if val, ok := seg.nodeMap[key]; ok {
node = seg.NodeData[val]
} else {
node = new(SysNode)
node.Key = key
seg.nodeMap[node.Key] = seg.nodeId
seg.nodeId++
seg.NodeData = append(seg.NodeData, node)
}
tsuMo = ""
isEqual = true
if syscfg.ValueFlag != "" {
nokoriStr := strings.TrimSpace(cowStr[i+1:])
isFound := false
isValue := false
for k4, v4 := range nokoriStr {
if string([]rune{v4}) == syscfg.ValueFlag {
isValue = !isValue
}
if SliceIn(syscfg.CommentFlag, string([]rune{v4})) && !isValue {
val := nokoriStr[:k4]
isFound = true
startFinder := strings.Index(val, syscfg.ValueFlag)
endFinder := strings.LastIndex(val, syscfg.ValueFlag)
if !((startFinder == -1 || endFinder == -1) || (endFinder-startFinder <= 0)) {
node.Value = append(node.Value, strings.TrimSpace(val[startFinder+1:endFinder]))
}
node.Comment = nokoriStr[k4+1:] + "\n"
}
}
if !isFound {
startFinder := strings.Index(nokoriStr, syscfg.ValueFlag)
endFinder := strings.LastIndex(nokoriStr, syscfg.ValueFlag)
if (startFinder == -1 || endFinder == -1) || (endFinder-startFinder <= 0) {
break
}
node.Value = append(node.Value, strings.TrimSpace(nokoriStr[startFinder+1:endFinder]))
}
break
}
continue
}
if SliceIn(syscfg.CommentFlag, runeStr) && (!isEscape) {
isComment = true
if seg.nodeId == 0 {
seg.Comment += strings.TrimSpace(cowStr[i+1:]) + "\n"
break
}
if tsuMo != "" {
node.Value = append(node.Value, strings.TrimSpace(tsuMo))
tsuMo = ""
}
node.Comment += strings.TrimSpace(cowStr[i+1:]) + "\n"
break
}
isEscape = false
tsuMo += runeStr
}
if isEqual && tsuMo != "" {
node.Value = append(node.Value, strings.TrimSpace(tsuMo))
}
if !isEqual && tsuMo != "" && !isComment {
node = new(SysNode)
node.Key = tsuMo
seg.nodeMap[node.Key] = seg.nodeId
seg.nodeId++
seg.NodeData = append(seg.NodeData, node)
node.NoValue = true
}
}
if seg != nil {
syscfg.segmap[seg.Name] = syscfg.segId
syscfg.segId++
syscfg.Data = append(syscfg.Data, seg)
}
return nil
}
func (syscfg *SysConf) parseCRComment(data []byte) error { //不允许comment在同一行
data = bytes.TrimSpace(data)
dataLists := bytes.Split(data, []byte("\n"))
seg := new(SysSegment)
seg.nodeMap = make(map[string]int64)
syscfg.segmap = make(map[string]int64)
if syscfg.HaveSegMent {
seg.Name = "unnamed"
}
syscfg.segId = 0
var node *SysNode
for _, v1 := range dataLists {
var (
isSegStart bool = false
isEscape bool = false
isEqual bool = false
isComment bool = false
tsuMo string = ""
)
cowStr := strings.TrimSpace(string(v1))
for i := 0; i < len(cowStr); i++ {
runeStr := cowStr[i : i+1] //当前字符rune扫描
if runeStr == syscfg.EscapeFlag && (!isEscape) {
isEscape = true
continue
}
if runeStr == syscfg.SegStart && (!isEscape) {
isSegStart = true
continue
}
if runeStr == syscfg.SegEnd && (!isEscape) {
isSegStart = false
//New segment start from here
if seg.Name == "unnamed" && len(seg.NodeData) == 0 {
seg.Name = tsuMo
tsuMo = ""
break
}
syscfg.segmap[seg.Name] = syscfg.segId
syscfg.segId++
syscfg.Data = append(syscfg.Data, seg)
seg = new(SysSegment)
seg.nodeMap = make(map[string]int64)
seg.Name = tsuMo
tsuMo = ""
break
}
if isSegStart {
tsuMo += runeStr
if isEscape {
isEscape = false
}
continue
}
if syscfg.EqualFlag == runeStr && (!isEscape) && (!isEqual) {
key := strings.TrimSpace(tsuMo)
if val, ok := seg.nodeMap[key]; ok {
node = seg.NodeData[val]
} else {
node = new(SysNode)
node.Key = key
seg.nodeMap[node.Key] = seg.nodeId
seg.nodeId++
seg.NodeData = append(seg.NodeData, node)
}
tsuMo = ""
isEqual = true
if syscfg.ValueFlag == "" {
node.Value = append(node.Value, TrimEscape(strings.TrimSpace(cowStr[i+1:]), syscfg.EscapeFlag))
} else {
nokoriStr := strings.TrimSpace(cowStr[i+1:])
startFinder := strings.Index(nokoriStr, syscfg.ValueFlag)
endFinder := strings.LastIndex(nokoriStr, syscfg.ValueFlag)
if (startFinder == -1 || endFinder == -1) || (endFinder-startFinder <= 0) {
break
}
node.Value = append(node.Value, strings.TrimSpace(nokoriStr[startFinder+1:endFinder]))
}
break
}
if SliceIn(syscfg.CommentFlag, runeStr) && (!isEscape) {
isComment = true
tsuMo = ""
if seg.nodeId == 0 {
seg.Comment += strings.TrimSpace(cowStr[i+1:]) + "\n"
break
}
node.Comment += strings.TrimSpace(cowStr[i+1:]) + "\n"
break
}
isEscape = false
tsuMo += runeStr
}
if !isEqual && tsuMo != "" && !isComment {
node = new(SysNode)
node.Key = tsuMo
seg.nodeMap[node.Key] = seg.nodeId
seg.nodeId++
seg.NodeData = append(seg.NodeData, node)
node.NoValue = true
}
}
if seg != nil {
syscfg.segmap[seg.Name] = syscfg.segId
syscfg.segId++
syscfg.Data = append(syscfg.Data, seg)
}
return nil
}
func (syscfg *SysConf) Build() []byte {
syscfg.lock.Lock()
defer syscfg.lock.Unlock()
var outPut string
for _, v := range syscfg.Data {
if v == nil {
continue
}
if syscfg.HaveSegMent {
outPut += syscfg.SegStart + v.Name + syscfg.SegEnd + "\n"
}
if v.Comment != "" {
v.Comment = v.Comment[:len(v.Comment)-1]
comment := strings.Split(v.Comment, "\n")
for _, vc := range comment {
if vc != "" {
outPut += syscfg.CommentFlag[0] + vc + "\n"
} else {
outPut += "\n"
}
}
}
for _, v2 := range v.NodeData {
if v2 == nil {
continue
}
if v2.NoValue {
outPut += v2.Key + "\n"
} else {
for _, v3 := range v2.Value {
if syscfg.ValueFlag != "" {
outPut += v2.Key + syscfg.SpaceStr + syscfg.EqualFlag + syscfg.SpaceStr + syscfg.ValueFlag + v3 + syscfg.ValueFlag + "\n"
} else {
outPut += v2.Key + syscfg.SpaceStr + syscfg.EqualFlag + syscfg.SpaceStr + syscfg.addEscape(v3) + "\n"
}
}
if len(v2.Value) == 0 {
outPut += v2.Key + syscfg.SpaceStr + syscfg.EqualFlag + "\n"
}
if v2.Comment != "" {
v2.Comment = v2.Comment[:len(v2.Comment)-1]
comment := strings.Split(v2.Comment, "\n")
for _, vc := range comment {
if vc != "" {
outPut += syscfg.CommentFlag[0] + vc + "\n"
} else {
outPut += "\n"
}
}
}
}
}
}
return []byte(outPut)
}
func (syscfg *SysConf) addEscape(str string) string {
str = strings.ReplaceAll(str, syscfg.EscapeFlag, syscfg.EscapeFlag+syscfg.EscapeFlag)
str = strings.ReplaceAll(str, syscfg.EqualFlag, syscfg.EscapeFlag+syscfg.EqualFlag)
str = strings.ReplaceAll(str, syscfg.SegStart, syscfg.EscapeFlag+syscfg.SegStart)
str = strings.ReplaceAll(str, syscfg.SegEnd, syscfg.EscapeFlag+syscfg.SegEnd)
for _, v := range syscfg.CommentFlag {
str = strings.ReplaceAll(str, v, syscfg.EscapeFlag+v)
}
return str
}
func (syscfg *SysConf) Reverse() {
syscfg.lock.Lock()
defer syscfg.lock.Unlock()
for _, v := range syscfg.Data {
if v == nil {
continue
}
var (
NodeData []*SysNode
nodeId int64
nodeMap map[string]int64
)
nodeMap = make(map[string]int64)
for _, v2 := range v.NodeData {
if v2 == nil {
continue
}
for _, v3 := range v2.Value {
var node *SysNode
if val, ok := nodeMap[v3]; ok {
node = NodeData[val]
} else {
node = new(SysNode)
node.Key = v3
NodeData = append(NodeData, node)
nodeMap[v3] = nodeId
nodeId++
}
node.Value = append(node.Value, strings.TrimSpace(v2.Key))
node.Comment += v2.Comment
}
v.NodeData = NodeData
v.nodeId = nodeId
v.nodeMap = nodeMap
}
}
}
func TrimEscape(text, escape string) string {
var isEscape bool = false
var outPut []rune
if escape == "" {
return text
}
text = strings.TrimSpace(text)
for _, v := range text {
if v == []rune(escape)[0] && !isEscape {
isEscape = true
continue
}
outPut = append(outPut, v)
}
return string(outPut)
}
func SliceIn(slice interface{}, data interface{}) bool {
typed := reflect.ValueOf(slice)
if typed.Kind() == reflect.Slice || typed.Kind() == reflect.Array {
for i := 0; i < typed.Len(); i++ {
if typed.Index(i).Interface() == data {
return true
}
}
}
return false
}
// Unmarshal 输出结果到结构体中
func (cfg *SysConf) Unmarshal(ins interface{}) error {
var structSet func(t reflect.Type, v reflect.Value, oriSeg string) error
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins).Elem()
if v.Kind() != reflect.Struct {
return errors.New("Not a Struct")
}
if t.Kind() != reflect.Ptr || !v.CanSet() {
return errors.New("Cannot Write!")
}
t = t.Elem()
structSet = func(t reflect.Type, v reflect.Value, oriSeg string) error {
for i := 0; i < t.NumField(); i++ {
tp := t.Field(i)
vl := v.Field(i)
if !vl.CanSet() {
continue
}
if vl.Type().Kind() == reflect.Struct {
structSet(vl.Type(), vl, tp.Tag.Get("seg"))
continue
}
seg := tp.Tag.Get("seg")
key := tp.Tag.Get("key")
if key != "" && seg == "" && cfg.HaveSegMent {
seg = "unnamed"
}
if oriSeg != "" {
seg = oriSeg
}
if seg == "" || key == "" {
continue
}
if _, ok := cfg.segmap[seg]; !ok {
continue
}
segs := cfg.Data[cfg.segmap[seg]]
if segs.Get(key) == "" {
continue
}
switch vl.Kind() {
case reflect.String:
vl.SetString(segs.Get(key))
case reflect.Int, reflect.Int32, reflect.Int64:
vl.SetInt(segs.Int64(key))
case reflect.Float32, reflect.Float64:
vl.SetFloat(segs.Float64(key))
case reflect.Bool:
vl.SetBool(segs.Bool(key))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
vl.SetUint(uint64(segs.Int64(key)))
default:
continue
}
}
return nil
}
return structSet(t, v, "")
}
// Marshal 输出结果到结构体中
func (cfg *SysConf) Marshal(ins interface{}) ([]byte, error) {
var structSet func(t reflect.Type, v reflect.Value, oriSeg string)
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins)
if v.Kind() != reflect.Struct {
return nil, errors.New("Not a Struct")
}
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
structSet = func(t reflect.Type, v reflect.Value, oriSeg string) {
for i := 0; i < t.NumField(); i++ {
var seg, key, comment string = "", "", ""
tp := t.Field(i)
vl := v.Field(i)
if vl.Type().Kind() == reflect.Struct {
structSet(vl.Type(), vl, tp.Tag.Get("seg"))
continue
}
seg = tp.Tag.Get("seg")
key = tp.Tag.Get("key")
comment = tp.Tag.Get("comment")
if oriSeg != "" {
seg = oriSeg
}
if seg == "" || key == "" {
continue
}
if _, ok := cfg.segmap[seg]; !ok {
cfg.AddSeg(seg)
}
cfg.Seg(seg).Set(key, fmt.Sprint(vl), comment)
}
}
structSet(t, v, "")
return cfg.Build(), nil
}
func (syscfg *SysConf) Seg(name string) *SysSegment {
if _, ok := syscfg.segmap[name]; !ok {
return nil
}
seg := syscfg.Data[syscfg.segmap[name]]
return seg
}
func (syscfg *SysConf) AddSeg(name string) *SysSegment {
syscfg.lock.Lock()
defer syscfg.lock.Unlock()
if _, ok := syscfg.segmap[name]; !ok {
newseg := new(SysSegment)
newseg.Name = name
newseg.nodeMap = make(map[string]int64)
syscfg.Data = append(syscfg.Data, newseg)
syscfg.segId++
if syscfg.segmap == nil {
syscfg.segId = 0
syscfg.segmap = make(map[string]int64)
}
syscfg.segmap[newseg.Name] = syscfg.segId
return newseg
}
seg := syscfg.Data[syscfg.segmap[name]]
return seg
}
func (syscfg *SysConf) DeleteSeg(name string) error {
syscfg.lock.Lock()
defer syscfg.lock.Unlock()
if _, ok := syscfg.segmap[name]; !ok {
return errors.New("Seg Not Exists")
}
syscfg.Data[syscfg.segmap[name]] = nil
delete(syscfg.segmap, name)
return nil
}
func (syscfg *SysSegment) GetComment(key string) string {
if v, ok := syscfg.nodeMap[key]; !ok {
return ""
} else {
return syscfg.NodeData[v].Comment
}
}
func (syscfg *SysSegment) SetComment(key, comment string) error {
syscfg.lock.Lock()
defer syscfg.lock.Unlock()
if v, ok := syscfg.nodeMap[key]; !ok {
return errors.New("Key Not Exists")
} else {
syscfg.NodeData[v].Comment = comment
return nil
}
}
func (syscfg *SysSegment) Exist(key string) bool {
if _, ok := syscfg.nodeMap[key]; !ok {
return false
} else {
return true
}
}
func (syscfg *SysSegment) Get(key string) string {
if v, ok := syscfg.nodeMap[key]; !ok {
return ""
} else {
if len(syscfg.NodeData[v].Value) >= 1 {
return syscfg.NodeData[v].Value[0]
}
}
return ""
}
func (syscfg *SysSegment) GetAll(key string) []string {
if v, ok := syscfg.nodeMap[key]; !ok {
return []string{}
} else {
return syscfg.NodeData[v].Value
}
}
func (syscfg *SysSegment) Int(key string) int {
val := syscfg.Get(key)
if val == "" {
return 0
}
res, _ := strconv.Atoi(val)
return res
}
func (syscfg *SysSegment) Int64(key string) int64 {
val := syscfg.Get(key)
if val == "" {
return 0
}
res, _ := strconv.ParseInt(val, 10, 64)
return res
}
func (syscfg *SysSegment) Int32(key string) int32 {
val := syscfg.Get(key)
if val == "" {
return 0
}
res, _ := strconv.ParseInt(val, 10, 32)
return int32(res)
}
func (syscfg *SysSegment) Float64(key string) float64 {
val := syscfg.Get(key)
if val == "" {
return 0
}
res, _ := strconv.ParseFloat(val, 64)
return res
}
func (syscfg *SysSegment) Float32(key string) float32 {
val := syscfg.Get(key)
if val == "" {
return 0
}
res, _ := strconv.ParseFloat(val, 32)
return float32(res)
}
func (syscfg *SysSegment) Bool(key string) bool {
val := syscfg.Get(key)
if val == "" {
return false
}
res, _ := strconv.ParseBool(val)
return res
}
func (syscfg *SysSegment) SetBool(key string, value bool, comment string) error {
res := strconv.FormatBool(value)
return syscfg.Set(key, res, comment)
}
func (syscfg *SysSegment) SetFloat64(key string, prec int, value float64, comment string) error {
res := strconv.FormatFloat(value, 'f', prec, 64)
return syscfg.Set(key, res, comment)
}
func (syscfg *SysSegment) SetFloat32(key string, prec int, value float32, comment string) error {
res := strconv.FormatFloat(float64(value), 'f', prec, 32)
return syscfg.Set(key, res, comment)
}
func (syscfg *SysSegment) SetUint64(key string, value uint64, comment string) error {
res := strconv.FormatUint(value, 10)
return syscfg.Set(key, res, comment)
}
func (syscfg *SysSegment) SetInt64(key string, value int64, comment string) error {
res := strconv.FormatInt(value, 10)
return syscfg.Set(key, res, comment)
}
func (syscfg *SysSegment) SetInt32(key string, value int32, comment string) error {
res := strconv.FormatInt(int64(value), 10)
return syscfg.Set(key, res, comment)
}
func (syscfg *SysSegment) SetInt(key string, value int, comment string) error {
res := strconv.Itoa(value)
return syscfg.Set(key, res, comment)
}
func (syscfg *SysSegment) Set(key, value, comment string) error {
syscfg.lock.Lock()
defer syscfg.lock.Unlock()
if v, ok := syscfg.nodeMap[key]; !ok {
node := new(SysNode)
node.Key = key
node.Value = append(node.Value, value)
node.Comment = comment
syscfg.NodeData = append(syscfg.NodeData, node)
syscfg.nodeMap[key] = syscfg.nodeId
syscfg.nodeId++
return nil
} else {
syscfg.NodeData[v].Value = []string{value}
if comment != "" {
syscfg.NodeData[v].Comment = comment
}
}
return nil
}
func (syscfg *SysSegment) SetAll(key string, value []string, comment string) error {
syscfg.lock.Lock()
defer syscfg.lock.Unlock()
if v, ok := syscfg.nodeMap[key]; !ok {
node := new(SysNode)
node.Key = key
node.Value = value
node.Comment = comment
syscfg.NodeData = append(syscfg.NodeData, node)
syscfg.nodeMap[key] = syscfg.nodeId
syscfg.nodeId++
return nil
} else {
syscfg.NodeData[v].Value = value
if comment != "" {
syscfg.NodeData[v].Comment = comment
}
}
return nil
}
func (syscfg *SysSegment) AddValue(key, value, comment string) error {
syscfg.lock.Lock()
defer syscfg.lock.Unlock()
if v, ok := syscfg.nodeMap[key]; !ok {
node := new(SysNode)
node.Key = key
node.Value = append(node.Value, value)
node.Comment = comment
syscfg.NodeData = append(syscfg.NodeData, node)
syscfg.nodeMap[key] = syscfg.nodeId
syscfg.nodeId++
return nil
} else {
syscfg.NodeData[v].Value = append(syscfg.NodeData[v].Value, value)
if comment != "" {
syscfg.NodeData[v].Comment = comment
}
}
return nil
}
func (syscfg *SysSegment) Delete(key string) error {
syscfg.lock.Lock()
defer syscfg.lock.Unlock()
if v, ok := syscfg.nodeMap[key]; !ok {
return errors.New("Key not exists!")
} else {
if syscfg.NodeData[v].Comment != "" {
cmtSet := false
for j := v - 1; j >= 0; j-- {
if syscfg.NodeData[j] != nil {
syscfg.NodeData[j].Comment += syscfg.NodeData[v].Comment
cmtSet = true
break
}
}
if !cmtSet {
syscfg.Comment += syscfg.NodeData[v].Comment
}
}
syscfg.NodeData[v] = nil
delete(syscfg.nodeMap, key)
}
return nil
}
func (syscfg *SysSegment) DeleteValue(key string, Value string) error {
syscfg.lock.Lock()
defer syscfg.lock.Unlock()
if v, ok := syscfg.nodeMap[key]; !ok {
return errors.New("Key not exists!")
} else {
data := syscfg.NodeData[v].Value
var vals []string
for _, v := range data {
if v != Value {
vals = append(vals, v)
}
}
syscfg.NodeData[v].Value = vals
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@ -1 +1,28 @@
package sysconf package sysconf
import "strconv"
func (s *Section) Uint64(key string) uint64 {
v, _ := strconv.ParseUint(s.Get(key), 10, 64)
return v
}
func (s *Section) MustInt(key string) int {
return s.Int(key)
}
func (s *Section) MustInt64(key string) int64 {
return s.Int64(key)
}
func (s *Section) MustUint64(key string) uint64 {
return s.Uint64(key)
}
func (s *Section) MustBool(key string) bool {
return s.Bool(key)
}
func (s *Section) MustFloat64(key string) float64 {
return s.Float64(key)
}

View File

@ -21,12 +21,14 @@ const (
TCP_TIME_WAIT TCP_TIME_WAIT
TCP_CLOSE TCP_CLOSE
TCP_CLOSE_WAIT TCP_CLOSE_WAIT
TCP_LAST_ACL TCP_LAST_ACK
TCP_LISTEN TCP_LISTEN
TCP_CLOSING TCP_CLOSING
) )
var TCP_STATE = []string{"TCP_UNKNOWN", "TCP_ESTABLISHED", "TCP_SYN_SENT", "TCP_SYN_RECV", "TCP_FIN_WAIT1", "TCP_FIN_WAIT2", "TCP_TIME_WAIT", "TCP_CLOSE", "TCP_CLOSE_WAIT", "TCP_LAST_ACL", "TCP_LISTEN", "TCP_CLOSING"} const TCP_LAST_ACL = TCP_LAST_ACK
var TCP_STATE = []string{"TCP_UNKNOWN", "TCP_ESTABLISHED", "TCP_SYN_SENT", "TCP_SYN_RECV", "TCP_FIN_WAIT1", "TCP_FIN_WAIT2", "TCP_TIME_WAIT", "TCP_CLOSE", "TCP_CLOSE_WAIT", "TCP_LAST_ACK", "TCP_LISTEN", "TCP_CLOSING"}
type NetAdapter struct { type NetAdapter struct {
Name string Name string

12
typed_test.go Normal file
View File

@ -0,0 +1,12 @@
package staros
import "testing"
func TestTCPStateLastACKSpellingAndCompatibilityAlias(t *testing.T) {
if TCP_STATE[TCP_LAST_ACK] != "TCP_LAST_ACK" {
t.Fatalf("unexpected LAST_ACK state string: %s", TCP_STATE[TCP_LAST_ACK])
}
if TCP_LAST_ACL != TCP_LAST_ACK {
t.Fatalf("TCP_LAST_ACL should remain a compatibility alias")
}
}