From d93a851d1b146bc39ed72deb57a5a65eb2b44a06 Mon Sep 17 00:00:00 2001 From: starainrt Date: Tue, 9 Jun 2026 18:10:19 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84=20staros=20=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E8=83=BD=E5=8A=9B=E5=B9=B6=E6=9B=B4=E6=96=B0=20wincmd?= =?UTF-8?q?=20=E5=8F=91=E5=B8=83=E7=89=88=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 重构 sysconf 为文档模型 INI Parser 与 Config Framework - 强化 hosts 解析、插入校验、写回与异常输入处理 - 完善 StarCmd 生命周期、等待 API、流式输出与 IO 重定向 - 扩展跨平台文件时间、文件锁、内存、进程与网络能力 - 将 Windows 进程适配更新到 b612.me/wincmd v0.1.0 - 移除本地 wincmd/win32api replace,改用发布版依赖 - 将最低 Go 版本提升到 1.18 - 补充 hosts、sysconf、FileLock、StarCmd 与平台适配回归测试 --- README.md | 119 ++ files.go | 10 +- files_darwin.go | 127 ++- files_test.go | 138 ++- files_unix.go | 135 ++- files_windows.go | 112 +- go.mod | 13 +- go.sum | 12 +- hosts/hosts.go | 743 ++++++++----- hosts/hosts_test.go | 582 +++++++++- math.go | 602 +++++----- math_test.go | 61 + memory_darwin.go | 5 +- memory_unix.go | 50 +- memory_windows.go | 8 +- network_darwin.go | 30 + network_test.go | 58 + network_unix.go | 128 ++- network_windows.go | 297 ++++- network_windows_test.go | 121 ++ os.go | 46 +- os_darwin.go | 69 ++ os_test.go | 105 +- os_unix.go | 125 ++- os_unix_test.go | 36 + os_windows.go | 34 +- process.go | 899 ++++++++++++--- process_darwin.go | 73 ++ process_linux_test.go | 110 ++ process_test.go | 820 +++++++++++++- process_unix.go | 406 ++++--- process_win.go | 131 ++- sysconf/config.go | 1716 +++++++++++++++++++++++++++++ sysconf/csv.go | 294 ++--- sysconf/csv_test.go | 38 - sysconf/csvconvert.go | 120 -- sysconf/document.go | 1192 ++++++++++++++++++++ sysconf/example_migration_test.go | 27 + sysconf/ini.go | 445 ++++++++ sysconf/sysconf.go | 855 -------------- sysconf/sysconf_test.go | 1097 +++++++++++++++++- sysconf/typed.go | 27 + typed.go | 6 +- typed_test.go | 12 + 44 files changed, 9774 insertions(+), 2260 deletions(-) create mode 100644 README.md create mode 100644 math_test.go create mode 100644 network_darwin.go create mode 100644 network_windows_test.go create mode 100644 os_darwin.go create mode 100644 os_unix_test.go create mode 100644 process_darwin.go create mode 100644 process_linux_test.go create mode 100644 sysconf/config.go delete mode 100644 sysconf/csv_test.go delete mode 100644 sysconf/csvconvert.go create mode 100644 sysconf/document.go create mode 100644 sysconf/example_migration_test.go create mode 100644 sysconf/ini.go delete mode 100644 sysconf/sysconf.go create mode 100644 typed_test.go diff --git a/README.md b/README.md new file mode 100644 index 0000000..e93a6e5 --- /dev/null +++ b/README.md @@ -0,0 +1,119 @@ +# staros + +`staros` is a cgo-free Go package for small cross-platform OS utilities. + +The package keeps compatibility with existing APIs, but platform-dependent functions should prefer explicit error-returning variants where available. Unsupported platform implementations return `ERR_UNSUPPORTED` instead of silently pretending to work. + +## Go Version + +`go.mod` declares `go 1.18`. This release targets Go 1.18 or newer and no longer promises Go 1.16/1.17 compatibility. The release gate includes the current local Go toolchain plus cross-platform compile checks. + +## Platform Support + +| Area | Linux | Windows | Darwin | +| --- | --- | --- | --- | +| Basic path checks: `Exists`, `IsFile`, `IsFolder` | Supported | Supported | Supported | +| File locks: `FileLock` | `flock` | Win32 lock API | `flock` | +| File timestamps: `GetFileCreationTime`, `GetFileAccessTime`, `SetFileTimesE` | ctime/access/modtime via `x/sys/unix` | creation/access/modtime via `x/sys/windows` | birth/access/modtime via stdlib/syscall | +| Memory: `Memory` | `/proc` and `syscall.Sysinfo` | Win32 `GlobalMemoryStatusEx` | `vm_stat` and `sysctl` | +| Disk: `DiskUsageE` | `Statfs` | `GetDiskFreeSpaceExW` | `Statfs` | +| OS identity: `IsRoot`, `Whoami` | Supported | `IsRoot` supported, `Whoami` returns `ERR_UNSUPPORTED` | Supported | +| CPU usage: `CpuUsage`, `CpuUsageByPid` | `/proc/stat` and `/proc//stat` | Stub returns `0` | Stub returns `0` | +| Process query: `FindProcess*` | `/proc` backed | `wincmd` backed basic fields | Returns `ERR_UNSUPPORTED` | +| Process launch: `Command`, `CommandContext`, `Start` | Supported | Supported | Supported | +| Process lifecycle: `ReleaseE` | Starts through normal lifecycle with `Setsid` | Starts through normal lifecycle | Returns `ERR_UNSUPPORTED` | +| Process detach: `DetachE` | Start-before-Start only, then `Process.Release` | Start-before-Start only, then `Process.Release` | Returns `ERR_UNSUPPORTED` | +| Run as user: `SetRunUserE`, `DaemonWithUser` | Supported | Returns `ERR_UNSUPPORTED` | Returns `ERR_UNSUPPORTED` | +| Keep capabilities: `SetKeepCaps`, `StarCmd.SetKeepCaps` | Package helper uses Linux `prctl`; command helper preserves current caps via `AmbientCaps` | Returns `ERR_UNSUPPORTED` | Returns `ERR_UNSUPPORTED` | +| Network adapters/speeds/connections | `/proc/net` backed | IP Helper adapter counters and TCP/UDP owner-pid tables; no Unix socket/inode fields | Returns `ERR_UNSUPPORTED` | +| Beep | PC speaker or terminal bell fallback | Win32 `Beep` | `osascript` or terminal bell fallback | + +## API Semantics + +- `ERR_UNSUPPORTED` means the symbol exists for build compatibility but the current OS implementation is intentionally unavailable. +- `ReleaseE` preserves `StarCmd` lifecycle observation: `Stopped()` still closes after the process exits and `ExitCode()` is populated when available. The historical misspelled `Stoped()` method remains as a deprecated compatibility alias. +- `DetachE` is a true detached start path and must be called before `Start`; calling it after `Start` returns an already-started error to avoid racing `Process.Release()` with the internal `Wait()`. +- `Wait`, `WaitContext`, and `WaitTimeout` provide explicit lifecycle wait helpers that return the final process wait error. `Stopped()` remains available when callers only need a close signal. +- On Linux, the package-level `SetKeepCaps()` helper applies `prctl(PR_SET_KEEPCAPS)` to the current process. `StarCmd.SetKeepCaps()` is different: it snapshots the current capability set and configures the child command's `SysProcAttr.AmbientCaps` before `Start`. +- `StdoutChan`, `StderrChan`, and `OutputChan` provide best-effort streaming observation for future output chunks. They close with `Stopped()` and do not replace the existing full-output capture methods. +- `RedirectStdout`, `RedirectStderr`, `RedirectOutput`, and `RedirectStdin` configure process IO before `Start`. File helpers such as `RedirectStdoutFile` and `RedirectStdinFile` open the file and close it after the process reaches its final state. +- `WriteStdinE` and `WriteStdinStringE` write raw stdin data without appending a newline. `WriteStdinLineE` and the legacy `WriteCmdE` append one newline. +- Legacy methods that do not return errors are kept for compatibility; prefer the `*E` variants for new code. + +## sysconf Migration + +The `sysconf` package now uses a document-backed INI model instead of the old `SysConf` struct shape. This is a breaking API change: callers should migrate to `NewIni`, `NewLinuxConf`, `Document`, `Section`, and `Entry` directly instead of relying on a field-compatible wrapper. + +Minimal migration rules: + +- Replace parser setup through `SysConf` fields with constructors: use `sysconf.NewIni()` for sectioned INI files, or `sysconf.NewLinuxConf(equal)` for flat Linux-style config files. +- Replace direct segment/key data mutation with section methods: `ini.Section(name)`, `ini.Set(section, key, value)`, `sec.Set`, `sec.SetAll`, `sec.AddValue`, `sec.Delete`, and `ini.DeleteSection`. +- Replace single-value assumptions with duplicate-aware reads where needed: `ini.Get`/`sec.Get` returns the first value, while `ini.GetAll`/`sec.GetAll` returns all repeated keys. +- Replace manual struct binding with `ini.Unmarshal(&dst)` and `ini.Marshal(src)` using `seg` and `key` tags. +- Use `ini.Build()` or `ini.Save(path)` for write-back. Unchanged parsed lines keep their original formatting, while changed entries are rebuilt from the new model. + +Example: + +```go +ini := sysconf.NewIni() +if err := ini.Parse(data); err != nil { + return err +} + +app := ini.Section("app") +if app == nil { + app = ini.AddSection("app") +} +_ = app.SetInt("port", 9090, "") +_ = app.SetAll("feature", []string{"stable", "audit"}, "") + +out := ini.Build() +``` + +INI parser capabilities: + +- `NewIni()` parses common sectioned INI files with `=` and `:` key/value delimiters, `#` and `;` comments, section header comments, quoted values, no-value keys, duplicate keys, duplicate sections, and backslash line continuation. +- Inline comments require whitespace before the comment marker, so values such as URLs or fragments containing `#` are not truncated accidentally. +- `NewIniWithProfiles(...)`, `StrictINIProfile()`, and `LinuxConfProfile(equal)` let callers pin parser behavior for strict sectioned INI or flat Linux-style config files without mutating parser fields ad hoc. +- `Document.Strict` can be enabled when callers want malformed input to return a `ParseError` with line and column information instead of preserving unknown lines as raw content. +- Write-back is lossless for unchanged parsed lines. Changed values that would otherwise be misread as comments, leading/trailing whitespace, tabs, or newlines are emitted as quoted values. + +Config framework capabilities: + +- `sysconf.NewConfig()`, `sysconf.LoadConfig(&dst, files, ...)`, and `sysconf.LoadConfigSources(&dst, sources, ...)` load one or more INI sources in order; later sources override earlier values for the same section/key. +- `sysconf.RequiredFile(path)` and `sysconf.OptionalFile(path)` declare whether a missing file should fail loading or be skipped. `sysconf.BytesSource(name, data)` and `sysconf.StringSource(name, data)` support in-memory overlays for tests, generated defaults, and embedded configs. +- `Config` exposes direct access and write-back helpers: `Get`, `GetAll`, error-returning typed getters such as `GetIntE` / `GetBoolE` / `GetDurationE`, `Has`, `Set`, `SetAll`, `Delete`, `Build`, `Save`, and `SaveAtomic`. +- Struct binding uses `seg`, `key`, `default`, `env`, `split`, and `required` tags. Nested structs inherit their parent `seg` tag, while `env:"-"` disables environment overrides for a field. +- Environment overrides are opt-in through `WithEnvPrefix` / `WithEnvLookup`; generated names normalize section and key names to uppercase underscore form, such as `APP_SERVER_PORT`. +- Binding supports strings, bools, signed/unsigned integers, floats, `time.Duration`, `encoding.TextUnmarshaler`, slices, arrays, and `map[string]T` for scalar `T`. +- Repeated INI keys bind naturally to slices. When a single value should expand into multiple collection items, add an explicit `split` tag such as `split:","`, `split:"|"`, or `split:"csv"`. +- `Config.SetStruct(src)` writes a config struct back into the current document, using repeated keys for slices/arrays and sorted `key=value` repeated entries for maps. +- `sysconf.DescribeConfig(src)` exports struct tag metadata as `ConfigFieldInfo` records, and `sysconf.SampleConfig(src)` builds a sample INI from defaults, current struct values, required placeholders, or type zero values without mutating the source struct. +- `Config.SectionNames()`, `Config.Keys(section)`, `Config.Flatten()`, and `Config.FlattenEntries()` expose sorted section/key discovery, duplicate-aware flattened values, and structured section/key/value entries for diagnostics, tests, and lightweight config export. +- `ConfigError` and `ConfigSourceError` include structured metadata and unwrap their underlying parse, file, or conversion errors for `errors.Is` / `errors.As`. +- A config struct can implement `Validate() error`; validation runs after defaults, file values, env overrides, required checks, and type binding. + +Example: + +```go +type AppConfig struct { + App struct { + Name string `key:"name" required:"true"` + Port int `key:"port" default:"8080"` + Timeout time.Duration `key:"timeout" default:"5s"` + Tags []string `key:"tag" env:"APP_TAGS"` + } `seg:"app"` +} + +var cfg AppConfig +_, err := sysconf.LoadConfigSources(&cfg, []sysconf.ConfigSource{ + sysconf.RequiredFile("/etc/app.ini"), + sysconf.OptionalFile("/etc/app.local.ini"), +}, sysconf.WithEnvPrefix("APP")) +``` + +The current framework is intentionally local-file focused. It does not include hot reload, remote configuration centers, secret managers, or a schema DSL. + +## Scope + +This package is intentionally small and cgo-free. Functionality that duplicates broad system inventory packages should stay frozen unless it improves cross-platform semantics, error observability, or compatibility for existing callers. diff --git a/files.go b/files.go index ba0b51b..d1c03a7 100644 --- a/files.go +++ b/files.go @@ -7,6 +7,12 @@ import ( var ERR_ALREADY_LOCKED = errors.New("ALREADY LOCKED") var ERR_TIMEOUT = errors.New("TIME OUT") +var ERR_UNSUPPORTED = errors.New("UNSUPPORTED") + +var errNilFile = errors.New("nil file") +var errNilFileInfo = errors.New("nil file info") +var errUnsupportedFileInfo = errors.New("unsupported file info") +var errFileLockNotLocked = errors.New("file lock is not locked") func NewFileLock(filepath string) FileLock { return FileLock{ @@ -24,7 +30,7 @@ func Exists(path string) bool { } // IsFile 返回给定文件地址是否是一个文件, -//True为是一个文件,False为不是文件或路径无效 +// True为是一个文件,False为不是文件或路径无效 func IsFile(fpath string) bool { s, err := os.Stat(fpath) if err != nil { @@ -34,7 +40,7 @@ func IsFile(fpath string) bool { } // IsFolder 返回给定文件地址是否是一个文件夹, -//True为是一个文件夹,False为不是文件夹或路径无效 +// True为是一个文件夹,False为不是文件夹或路径无效 func IsFolder(fpath string) bool { s, err := os.Stat(fpath) if err != nil { diff --git a/files_darwin.go b/files_darwin.go index 919dcee..64d2be0 100644 --- a/files_darwin.go +++ b/files_darwin.go @@ -1,9 +1,9 @@ -//+build darwin +//go:build darwin +// +build darwin package staros import ( - "b612.me/stario" "os" "syscall" "time" @@ -12,6 +12,7 @@ import ( type FileLock struct { fd int filepath string + locked bool } func (f *FileLock) openFileForLock() error { @@ -19,7 +20,6 @@ func (f *FileLock) openFileForLock() error { if err != nil { return err } - f.filepath = f.filepath f.fd = fd return nil } @@ -31,10 +31,7 @@ func (f *FileLock) Lock(Exclusive bool) error { } else { lockType = syscall.LOCK_SH } - if err := f.openFileForLock(); err != nil { - return err - } - return syscall.Flock(f.fd, lockType) + return f.lockWithFlags(lockType) } func (f *FileLock) LockNoBlocking(Exclusive bool) error { @@ -44,38 +41,78 @@ func (f *FileLock) LockNoBlocking(Exclusive bool) error { } else { lockType = syscall.LOCK_SH } + return f.lockWithFlags(lockType | syscall.LOCK_NB) +} + +func (f *FileLock) lockWithFlags(lockType int) error { + if f.locked { + return ERR_ALREADY_LOCKED + } if err := f.openFileForLock(); err != nil { return err } - err := syscall.Flock(f.fd, lockType|syscall.LOCK_NB) + err := syscall.Flock(f.fd, lockType) if err != nil { - syscall.Close(f.fd) + _ = syscall.Close(f.fd) + f.fd = 0 if err == syscall.EWOULDBLOCK { return ERR_ALREADY_LOCKED } + return err } - return err + f.locked = true + return nil } func (f *FileLock) Unlock() error { + if f == nil || !f.locked { + return errFileLockNotLocked + } err := syscall.Flock(f.fd, syscall.LOCK_UN) if err != nil { return err } - return syscall.Close(f.fd) + if err := syscall.Close(f.fd); err != nil { + return err + } + f.fd = 0 + f.locked = false + return nil } func (f *FileLock) LockWithTimeout(tm time.Duration, Exclusive bool) error { - return stario.WaitUntilTimeout(tm, func(tmout chan struct{}) error { - err := f.Lock(Exclusive) - select { - case <-tmout: - f.Unlock() + if f.locked { + return ERR_ALREADY_LOCKED + } + var lockType int + if Exclusive { + lockType = syscall.LOCK_EX + } else { + lockType = syscall.LOCK_SH + } + if tm < 0 { + return f.Lock(Exclusive) + } + deadline := time.Now().Add(tm) + for { + err := f.lockWithFlags(lockType | syscall.LOCK_NB) + if err == nil { return nil - default: } - return err - }) + if err != ERR_ALREADY_LOCKED { + return err + } + if !time.Now().Before(deadline) { + return ERR_TIMEOUT + } + sleep := time.Millisecond * 10 + if remaining := time.Until(deadline); remaining < sleep { + sleep = remaining + } + if sleep > 0 { + time.Sleep(sleep) + } + } } func timespecToTime(ts syscall.Timespec) time.Time { @@ -83,9 +120,57 @@ func timespecToTime(ts syscall.Timespec) time.Time { } func GetFileCreationTime(fileinfo os.FileInfo) time.Time { - return timespecToTime(fileinfo.Sys().(*syscall.Stat_t).Ctimespec) + if fileinfo == nil { + return time.Time{} + } + if stat, ok := fileinfo.Sys().(*syscall.Stat_t); ok && stat != nil { + return timespecToTime(stat.Birthtimespec) + } + return time.Time{} } func GetFileAccessTime(fileinfo os.FileInfo) time.Time { - return timespecToTime(fileinfo.Sys().(*syscall.Stat_t).Atimespec) + if fileinfo == nil { + return time.Time{} + } + if stat, ok := fileinfo.Sys().(*syscall.Stat_t); ok && stat != nil { + return timespecToTime(stat.Atimespec) + } + return time.Time{} +} + +func SetFileTimes(file *os.File, info os.FileInfo) { + _ = SetFileTimesE(file, info) +} + +func SetFileTimesbyTime(file *os.File) { + _ = SetFileTimesbyTimeE(file) +} + +func SetFileTimesE(file *os.File, info os.FileInfo) error { + if file == nil { + return errNilFile + } + if info == nil { + return errNilFileInfo + } + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok || stat == nil { + return errUnsupportedFileInfo + } + atime := timespecToTime(stat.Atimespec) + mtime := info.ModTime() + return os.Chtimes(file.Name(), atime, mtime) +} + +func SetFileTimesByTimeE(file *os.File) error { + return SetFileTimesbyTimeE(file) +} + +func SetFileTimesbyTimeE(file *os.File) error { + if file == nil { + return errNilFile + } + now := time.Now() + return os.Chtimes(file.Name(), now, now) } diff --git a/files_test.go b/files_test.go index 7d7637c..e925d51 100644 --- a/files_test.go +++ b/files_test.go @@ -1,23 +1,139 @@ package staros import ( - "fmt" + "errors" "os" + "path/filepath" + "runtime" "testing" "time" ) func Test_FileLock(t *testing.T) { - filename := "./test.file" + filename := filepath.Join(t.TempDir(), "test.file") lock := NewFileLock(filename) lock2 := NewFileLock(filename) - fmt.Println("lock1", lock.LockNoBlocking(false)) - time.Sleep(time.Second) - fmt.Println("lock2", lock2.LockWithTimeout(time.Second*5, false)) - fmt.Println("unlock1", lock.Unlock()) - time.Sleep(time.Second) - fmt.Println("unlock2", lock2.Unlock()) - fmt.Println("lock2", lock2.LockNoBlocking(true)) - fmt.Println("unlock2", lock2.Unlock()) - os.Remove(filename) + if err := lock.LockNoBlocking(false); err != nil { + t.Fatal(err) + } + if err := lock2.LockNoBlocking(false); err != nil { + t.Fatal(err) + } + if err := lock.Unlock(); err != nil { + t.Fatal(err) + } + if err := lock2.Unlock(); err != nil { + t.Fatal(err) + } + if err := lock2.LockNoBlocking(true); err != nil { + t.Fatal(err) + } + if err := lock2.Unlock(); err != nil { + t.Fatal(err) + } + _ = os.Remove(filename) +} + +func TestFileLockExclusiveConflictTimeout(t *testing.T) { + filename := filepath.Join(t.TempDir(), "timeout.file") + lock := NewFileLock(filename) + contender := NewFileLock(filename) + if err := lock.Lock(true); err != nil { + t.Fatal(err) + } + defer lock.Unlock() + + if err := contender.LockNoBlocking(true); !errors.Is(err, ERR_ALREADY_LOCKED) { + if err == nil { + _ = contender.Unlock() + } + t.Fatalf("expected non-blocking exclusive lock conflict, got %v", err) + } + + start := time.Now() + if err := contender.LockWithTimeout(50*time.Millisecond, true); !errors.Is(err, ERR_TIMEOUT) { + if err == nil { + _ = contender.Unlock() + } + t.Fatalf("expected exclusive lock timeout, got %v", err) + } + if elapsed := time.Since(start); elapsed > time.Second { + t.Fatalf("lock timeout took too long: %s", elapsed) + } + + if err := lock.Unlock(); err != nil { + t.Fatal(err) + } + if err := contender.LockWithTimeout(time.Second, true); err != nil { + t.Fatalf("expected lock after owner unlock, got %v", err) + } + if err := contender.Unlock(); err != nil { + t.Fatal(err) + } +} + +func TestFileLockUnlockWithoutSuccessfulLock(t *testing.T) { + filename := filepath.Join(t.TempDir(), "unlock-state.file") + lock := NewFileLock(filename) + if err := lock.Unlock(); !errors.Is(err, errFileLockNotLocked) { + t.Fatalf("expected unlock without lock error, got %v", err) + } + + owner := NewFileLock(filename) + contender := NewFileLock(filename) + if err := owner.Lock(true); err != nil { + t.Fatal(err) + } + defer owner.Unlock() + if err := contender.LockNoBlocking(true); !errors.Is(err, ERR_ALREADY_LOCKED) { + if err == nil { + _ = contender.Unlock() + } + t.Fatalf("expected lock conflict, got %v", err) + } + if err := contender.Unlock(); !errors.Is(err, errFileLockNotLocked) { + t.Fatalf("failed lock attempt should not be unlockable, got %v", err) + } + if err := owner.Unlock(); err != nil { + t.Fatal(err) + } +} + +func TestFileLockRejectsSecondLockOnSameObject(t *testing.T) { + filename := filepath.Join(t.TempDir(), "same-object.file") + lock := NewFileLock(filename) + if err := lock.Lock(true); err != nil { + t.Fatal(err) + } + defer lock.Unlock() + + for name, fn := range map[string]func() error{ + "Lock": func() error { return lock.Lock(true) }, + "LockNoBlocking": func() error { return lock.LockNoBlocking(true) }, + "LockWithTimeout": func() error { return lock.LockWithTimeout(time.Second, true) }, + } { + if err := fn(); !errors.Is(err, ERR_ALREADY_LOCKED) { + if err == nil { + _ = lock.Unlock() + } + t.Fatalf("expected %s on same lock object to reject second lock, got %v", name, err) + } + } +} + +func TestGetFileCreationTimeLinuxUnavailable(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("linux-only creation time fallback") + } + filename := filepath.Join(t.TempDir(), "creation-time.file") + if err := os.WriteFile(filename, []byte("demo"), 0o644); err != nil { + t.Fatal(err) + } + info, err := os.Stat(filename) + if err != nil { + t.Fatal(err) + } + if got := GetFileCreationTime(info); !got.IsZero() { + t.Fatalf("linux FileInfo should not report synthetic creation time, got %s", got) + } } diff --git a/files_unix.go b/files_unix.go index 132abcd..b27707e 100644 --- a/files_unix.go +++ b/files_unix.go @@ -1,9 +1,10 @@ -//+build linux +//go:build linux +// +build linux package staros import ( - "b612.me/stario" + "golang.org/x/sys/unix" "os" "syscall" "time" @@ -12,6 +13,7 @@ import ( type FileLock struct { fd int filepath string + locked bool } func timespecToTime(ts syscall.Timespec) time.Time { @@ -19,11 +21,66 @@ func timespecToTime(ts syscall.Timespec) time.Time { } func GetFileCreationTime(fileinfo os.FileInfo) time.Time { - return timespecToTime(fileinfo.Sys().(*syscall.Stat_t).Ctim) + if fileinfo == nil { + return time.Time{} + } + // Linux os.FileInfo/syscall.Stat_t does not expose a stable birth time. + // Returning ctime here would be wrong because it tracks inode changes. + return time.Time{} } func GetFileAccessTime(fileinfo os.FileInfo) time.Time { - return timespecToTime(fileinfo.Sys().(*syscall.Stat_t).Atim) + if fileinfo == nil { + return time.Time{} + } + if stat, ok := fileinfo.Sys().(*syscall.Stat_t); ok && stat != nil { + return timespecToTime(stat.Atim) + } + return time.Time{} +} + +func SetFileTimes(file *os.File, info os.FileInfo) { + _ = SetFileTimesE(file, info) +} + +func SetFileTimesbyTime(file *os.File) { + _ = SetFileTimesbyTimeE(file) +} + +func SetFileTimesE(file *os.File, info os.FileInfo) error { + if file == nil { + return errNilFile + } + if info == nil { + return errNilFileInfo + } + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok || stat == nil { + return errUnsupportedFileInfo + } + atime := timespecToTime(stat.Atim) + mtime := info.ModTime() + return setFileTimes(file.Name(), atime, mtime) +} + +func SetFileTimesByTimeE(file *os.File) error { + return SetFileTimesbyTimeE(file) +} + +func SetFileTimesbyTimeE(file *os.File) error { + if file == nil { + return errNilFile + } + now := time.Now() + return setFileTimes(file.Name(), now, now) +} + +func setFileTimes(path string, atime, mtime time.Time) error { + ts := [2]unix.Timespec{ + unix.NsecToTimespec(atime.UnixNano()), + unix.NsecToTimespec(mtime.UnixNano()), + } + return unix.UtimesNanoAt(unix.AT_FDCWD, path, ts[:], unix.AT_SYMLINK_NOFOLLOW) } func (f *FileLock) openFileForLock() error { @@ -31,7 +88,6 @@ func (f *FileLock) openFileForLock() error { if err != nil { return err } - f.filepath = f.filepath f.fd = fd return nil } @@ -43,10 +99,7 @@ func (f *FileLock) Lock(Exclusive bool) error { } else { lockType = syscall.LOCK_SH } - if err := f.openFileForLock(); err != nil { - return err - } - return syscall.Flock(f.fd, lockType) + return f.lockWithFlags(lockType) } func (f *FileLock) LockNoBlocking(Exclusive bool) error { @@ -56,36 +109,76 @@ func (f *FileLock) LockNoBlocking(Exclusive bool) error { } else { lockType = syscall.LOCK_SH } + return f.lockWithFlags(lockType | syscall.LOCK_NB) +} + +func (f *FileLock) lockWithFlags(lockType int) error { + if f.locked { + return ERR_ALREADY_LOCKED + } if err := f.openFileForLock(); err != nil { return err } - err := syscall.Flock(f.fd, lockType|syscall.LOCK_NB) + err := syscall.Flock(f.fd, lockType) if err != nil { - syscall.Close(f.fd) + _ = syscall.Close(f.fd) + f.fd = 0 if err == syscall.EWOULDBLOCK { return ERR_ALREADY_LOCKED } + return err } - return err + f.locked = true + return nil } func (f *FileLock) Unlock() error { + if f == nil || !f.locked { + return errFileLockNotLocked + } err := syscall.Flock(f.fd, syscall.LOCK_UN) if err != nil { return err } - return syscall.Close(f.fd) + if err := syscall.Close(f.fd); err != nil { + return err + } + f.fd = 0 + f.locked = false + return nil } func (f *FileLock) LockWithTimeout(tm time.Duration, Exclusive bool) error { - return stario.WaitUntilTimeout(tm, func(tmout chan struct{}) error { - err := f.Lock(Exclusive) - select { - case <-tmout: - f.Unlock() + if f.locked { + return ERR_ALREADY_LOCKED + } + var lockType int + if Exclusive { + lockType = syscall.LOCK_EX + } else { + lockType = syscall.LOCK_SH + } + if tm < 0 { + return f.Lock(Exclusive) + } + deadline := time.Now().Add(tm) + for { + err := f.lockWithFlags(lockType | syscall.LOCK_NB) + if err == nil { return nil - default: } - return err - }) + if err != ERR_ALREADY_LOCKED { + return err + } + if !time.Now().Before(deadline) { + return ERR_TIMEOUT + } + sleep := time.Millisecond * 10 + if remaining := time.Until(deadline); remaining < sleep { + sleep = remaining + } + if sleep > 0 { + time.Sleep(sleep) + } + } } diff --git a/files_windows.go b/files_windows.go index 6ca87d9..ef8caac 100644 --- a/files_windows.go +++ b/files_windows.go @@ -1,9 +1,11 @@ +//go:build windows // +build windows package staros import ( "b612.me/win32api" + "golang.org/x/sys/windows" "os" "syscall" "time" @@ -12,23 +14,89 @@ import ( type FileLock struct { filepath string handle win32api.HANDLE + locked bool } func GetFileCreationTime(fileinfo os.FileInfo) time.Time { - d := fileinfo.Sys().(*syscall.Win32FileAttributeData) - return time.Unix(0, d.CreationTime.Nanoseconds()) + if fileinfo == nil { + return time.Time{} + } + if data, ok := fileinfo.Sys().(*syscall.Win32FileAttributeData); ok && data != nil { + return time.Unix(0, data.CreationTime.Nanoseconds()) + } + return time.Time{} } func GetFileAccessTime(fileinfo os.FileInfo) time.Time { - d := fileinfo.Sys().(*syscall.Win32FileAttributeData) - return time.Unix(0, d.LastAccessTime.Nanoseconds()) + if fileinfo == nil { + return time.Time{} + } + if data, ok := fileinfo.Sys().(*syscall.Win32FileAttributeData); ok && data != nil { + return time.Unix(0, data.LastAccessTime.Nanoseconds()) + } + return time.Time{} } func SetFileTimes(file *os.File, info os.FileInfo) { - + _ = SetFileTimesE(file, info) } func SetFileTimesbyTime(file *os.File) { + _ = SetFileTimesbyTimeE(file) +} + +func SetFileTimesE(file *os.File, info os.FileInfo) error { + if file == nil { + return errNilFile + } + if info == nil { + return errNilFileInfo + } + data, ok := info.Sys().(*syscall.Win32FileAttributeData) + if !ok || data == nil { + return errUnsupportedFileInfo + } + ctime := time.Unix(0, data.CreationTime.Nanoseconds()) + atime := time.Unix(0, data.LastAccessTime.Nanoseconds()) + mtime := info.ModTime() + return setFileTimes(file.Name(), ctime, atime, mtime) +} + +func SetFileTimesByTimeE(file *os.File) error { + return SetFileTimesbyTimeE(file) +} + +func SetFileTimesbyTimeE(file *os.File) error { + if file == nil { + return errNilFile + } + now := time.Now() + return setFileTimes(file.Name(), now, now, now) +} + +func setFileTimes(path string, ctime, atime, mtime time.Time) error { + path16, err := windows.UTF16PtrFromString(path) + if err != nil { + return err + } + handle, err := windows.CreateFile( + path16, + windows.FILE_WRITE_ATTRIBUTES, + windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, + nil, + windows.OPEN_EXISTING, + windows.FILE_FLAG_BACKUP_SEMANTICS, + 0, + ) + if err != nil { + return err + } + defer windows.CloseHandle(handle) + + ctimeFt := windows.NsecToFiletime(ctime.UnixNano()) + atimeFt := windows.NsecToFiletime(atime.UnixNano()) + mtimeFt := windows.NsecToFiletime(mtime.UnixNano()) + return windows.SetFileTime(handle, &ctimeFt, &atimeFt, &mtimeFt) } @@ -53,21 +121,27 @@ func (f *FileLock) openFileForLock() error { } func (f *FileLock) lockForTimeout(timeout time.Duration, lockType win32api.DWORD) error { + if f.locked { + return ERR_ALREADY_LOCKED + } var err error if err = f.openFileForLock(); err != nil { return err } event, err := win32api.CreateEventW(nil, true, false, nil) if err != nil { + _ = f.closeHandle() return err } myEvent := &syscall.Overlapped{HEvent: syscall.Handle(event)} defer syscall.CloseHandle(myEvent.HEvent) _, err = win32api.LockFileEx(f.handle, lockType, 0, 1, 0, myEvent) if err == nil { + f.locked = true return nil } if err != syscall.ERROR_IO_PENDING { + _ = f.closeHandle() return err } millis := uint32(syscall.INFINITE) @@ -78,12 +152,13 @@ func (f *FileLock) lockForTimeout(timeout time.Duration, lockType win32api.DWORD switch s { case syscall.WAIT_OBJECT_0: // success! + f.locked = true return nil case syscall.WAIT_TIMEOUT: - f.Unlock() + _ = f.closeHandle() return ERR_TIMEOUT default: - f.Unlock() + _ = f.closeHandle() return err } } @@ -95,7 +170,7 @@ func (f *FileLock) Lock(Exclusive bool) error { } else { lockType = 0 } - return f.lockForTimeout(0, lockType) + return f.lockForTimeout(-1, lockType) } func (f *FileLock) LockWithTimeout(tm time.Duration, Exclusive bool) error { @@ -119,5 +194,24 @@ func (f *FileLock) LockNoBlocking(Exclusive bool) error { } func (f *FileLock) Unlock() error { - return syscall.Close(syscall.Handle(f.handle)) + if f == nil || !f.locked { + return errFileLockNotLocked + } + if err := f.closeHandle(); err != nil { + return err + } + f.locked = false + return nil +} + +func (f *FileLock) closeHandle() error { + if f == nil || f.handle == 0 { + return nil + } + err := syscall.Close(syscall.Handle(f.handle)) + if err != nil { + return err + } + f.handle = 0 + return nil } diff --git a/go.mod b/go.mod index 52b9067..6b44049 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,15 @@ module b612.me/staros -go 1.16 +go 1.18 require ( - b612.me/stario v0.0.10 - b612.me/win32api v0.0.2 - b612.me/wincmd v0.0.4 + b612.me/win32api v0.0.4 + b612.me/wincmd v0.1.0 golang.org/x/sys v0.24.0 ) + +require ( + b612.me/stario v0.0.11 // indirect + golang.org/x/crypto v0.26.0 // indirect + golang.org/x/term v0.23.0 // indirect +) diff --git a/go.sum b/go.sum index e36a9c7..d20738f 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,9 @@ -b612.me/stario v0.0.10 h1:+cIyiDCBCjUfodMJDp4FLs+2E1jo7YENkN+sMEe6550= -b612.me/stario v0.0.10/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk= -b612.me/win32api v0.0.2 h1:5PwvPR5fYs3a/v+LjYdtRif+5Q04zRGLTVxmCYNjCpA= -b612.me/win32api v0.0.2/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0= -b612.me/wincmd v0.0.4 h1:fv9p1V8mw2HdUjaoZBWZy0T41JftueyLxAuch1MgtdI= -b612.me/wincmd v0.0.4/go.mod h1:o3yPoE+DpVPHGKl/q1WT1C8OaIVwHEnpeNgMFqzlwD8= +b612.me/stario v0.0.11 h1:H5SN5G36ZlW7Lu5co3CWK59eHVJduqHSa9a29Cx5ExQ= +b612.me/stario v0.0.11/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk= +b612.me/win32api v0.0.4 h1:V3LgCTbl8UF0Tb1UJDXl8+F/404yLA0XtC/131KmQ7c= +b612.me/win32api v0.0.4/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0= +b612.me/wincmd v0.1.0 h1:hLOvoIvsPhesb7XbN0l+pcfu4YNWog7YYw11MAkOiDs= +b612.me/wincmd v0.1.0/go.mod h1:kSUbCBCBciJQZi8V2gP78ZEtt8yUHaLatl/5X+V+4Fc= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/hosts/hosts.go b/hosts/hosts.go index 411d56e..b7640dc 100644 --- a/hosts/hosts.go +++ b/hosts/hosts.go @@ -188,6 +188,79 @@ func NewHosts() *Host { } } +func cloneHostNode(node *HostNode) *HostNode { + if node == nil { + return nil + } + cloned := *node + cloned.host = append([]string(nil), node.host...) + return &cloned +} + +func cloneHostNodes(nodes []*HostNode) []*HostNode { + if len(nodes) == 0 { + return nil + } + cloned := make([]*HostNode, 0, len(nodes)) + for _, node := range nodes { + cloned = append(cloned, cloneHostNode(node)) + } + return cloned +} + +func copyHostNodeState(dst, src *HostNode) { + if dst == nil || src == nil { + return + } + dst.uid = src.uid + dst.nextuid = src.nextuid + dst.lastuid = src.lastuid + dst.ip = src.ip + dst.host = append(dst.host[:0], src.host...) + dst.comment = src.comment + dst.original = src.original + dst.onlyComment = src.onlyComment + dst.valid = src.valid +} + +func normalizeEditableNode(node *HostNode) (*HostNode, error) { + if node == nil { + return nil, fmt.Errorf("node not exists") + } + candidate := cloneHostNode(node) + if candidate.comment != "" && !strings.HasPrefix(strings.TrimSpace(candidate.comment), "#") { + candidate.comment = "#" + candidate.comment + } + if len(candidate.host) > 0 { + hosts, err := normalizeHostTokens(candidate.host...) + if err != nil { + return nil, err + } + candidate.host = hosts + } + if candidate.ip != "" || len(candidate.host) > 0 { + ip, err := normalizeIPToken(candidate.ip) + if err != nil { + return nil, err + } + candidate.ip = ip + if len(candidate.host) == 0 { + return nil, fmt.Errorf("empty host") + } + } else { + candidate.ip = strings.TrimSpace(candidate.ip) + } + candidate.onlyComment = candidate.ip == "" && len(candidate.host) == 0 && candidate.comment != "" + if !candidate.onlyComment && candidate.ip == "" && len(candidate.host) == 0 { + return nil, fmt.Errorf("empty node") + } + candidate.valid = candidate.CheckValid() + if !candidate.valid { + return nil, fmt.Errorf("invalid hosts node") + } + return candidate, nil +} + func (h *Host) Parse(hostPath string) error { h.Lock() defer h.Unlock() @@ -210,41 +283,27 @@ func (h *Host) parse() error { buf := bufio.NewReader(f) for { line, err := buf.ReadString('\n') - if err == io.EOF { - if h.idx-1 >= 0 { - h.fulldata[h.idx].nextuid = 0 - h.lastUid = h.idx - } - break - } - if err != nil { + if err != nil && err != io.EOF { return fmt.Errorf("read hosts file error: %s", err) } - h.idx++ - line = strings.TrimSpace(line) - data, _ := h.parseLine(line) - data.uid = h.idx - data.lastuid = h.idx - 1 - data.nextuid = h.idx + 1 - h.fulldata[data.uid] = &data - if h.firstUid == 0 { - h.firstUid = h.idx + raw := strings.TrimRight(line, "\r\n") + if err == nil || (err == io.EOF && raw != "") { + data, _ := h.parseLine(raw) + h.appendNodeLocked(&data) } - if data.valid { - for _, v := range data.host { - h.hostData[v] = append(h.hostData[v], &data) - } - h.ipData[data.ip] = append(h.ipData[data.ip], &data) + if err == io.EOF { + break } } return nil } -func (h *Host) parseLine(data string) (HostNode, error) { +func (h *Host) parseLine(raw string) (HostNode, error) { var res = HostNode{ - original: data, + original: raw, } + data := strings.TrimSpace(raw) if len(data) == 0 { return res, fmt.Errorf("empty line") } @@ -311,7 +370,7 @@ func (h *Host) List() []*HostNode { if nextUid == 0 { break } - res = append(res, h.fulldata[nextUid]) + res = append(res, cloneHostNode(h.fulldata[nextUid])) nextUid = h.fulldata[nextUid].nextuid } return res @@ -353,7 +412,7 @@ func (h *Host) ListByHost(host string) []*HostNode { if h.hostData == nil { return nil } - return h.hostData[host] + return cloneHostNodes(h.hostData[host]) } func (h *Host) ListIPsByHost(host string) []string { @@ -387,7 +446,7 @@ func (h *Host) ListByIP(ip string) []*HostNode { if h.ipData == nil { return nil } - return h.ipData[ip] + return cloneHostNodes(h.ipData[ip]) } func (h *Host) ListHostsByIP(ip string) []string { @@ -443,43 +502,27 @@ cntfor: for _, v := range ipInfo { for _, host := range hosts { if len(v.host) == 1 && v.host[0] == host { - delete(h.ipData, ip) - if v.lastuid != 0 { - h.fulldata[v.lastuid].nextuid = v.nextuid - } else { - h.firstUid = v.nextuid - } - if v.nextuid != 0 { - h.fulldata[v.nextuid].lastuid = v.lastuid - } else { - h.lastUid = v.lastuid - } - var newHostData []*HostNode - for _, vv := range h.hostData[v.host[0]] { - if vv.uid != v.uid { - newHostData = append(newHostData, vv) - } - } - h.hostData[host] = newHostData - delete(h.fulldata, v.uid) + h.removeNodeFromIPDataLocked(ip, v.uid) + h.removeNodeFromHostDataLocked(host, v.uid) + h.unlinkNodeLocked(v) v = nil continue cntfor } if len(v.host) > 1 { var newHosts []string + removed := false for _, vv := range v.host { if vv != host { newHosts = append(newHosts, vv) + } else { + removed = true } } + if !removed { + continue + } v.host = newHosts - var newHostData []*HostNode - for _, vv := range h.hostData[host] { - if vv.uid != v.uid { - newHostData = append(newHostData, vv) - } - } - h.hostData[host] = newHostData + h.removeNodeFromHostDataLocked(host, v.uid) } } } @@ -493,33 +536,7 @@ func (h *Host) RemoveIPs(ips ...string) error { return fmt.Errorf("hosts data not initialized") } for _, ip := range ips { - ipInfo := h.ipData[ip] - if len(ipInfo) == 0 { - continue - } - for _, v := range ipInfo { - delete(h.ipData, ip) - delete(h.fulldata, v.uid) - if v.lastuid != 0 { - h.fulldata[v.lastuid].nextuid = v.nextuid - } else { - h.firstUid = v.nextuid - } - if v.nextuid != 0 { - h.fulldata[v.nextuid].lastuid = v.lastuid - } else { - h.lastUid = v.lastuid - } - for _, host := range v.host { - var newHostData []*HostNode - for _, vv := range h.hostData[host] { - if vv.uid != v.uid { - newHostData = append(newHostData, vv) - } - } - h.hostData[host] = newHostData - } - } + h.removeIPLocked(ip) } return nil } @@ -531,64 +548,67 @@ func (h *Host) RemoveHosts(hosts ...string) error { return fmt.Errorf("hosts data not initialized") } for _, host := range hosts { - hostInfo := h.hostData[host] - if len(hostInfo) == 0 { - continue - } - delete(h.hostData, host) - for _, v := range hostInfo { - var newHosts []string - for _, vv := range v.host { - if vv != host { - newHosts = append(newHosts, vv) - } - } - v.host = newHosts - if len(v.host) == 0 { - delete(h.ipData, v.ip) - if v.lastuid != 0 { - h.fulldata[v.lastuid].nextuid = v.nextuid - } else { - h.firstUid = v.nextuid - } - if v.nextuid != 0 { - h.fulldata[v.nextuid].lastuid = v.lastuid - } else { - h.lastUid = v.lastuid - } - delete(h.fulldata, v.uid) - } - } + h.removeHostLocked(host) } return nil } func (h *Host) SetIPHosts(ip string, hosts ...string) error { - info := h.ListByIP(ip) - if len(info) == 0 { - return h.AddHosts(ip, hosts...) - } else if len(info) == 1 { - info[0].host = hosts - info[0].comment = "" - return nil - } - err := h.RemoveIPs(ip) + var err error + ip, err = normalizeIPToken(ip) if err != nil { return err } - return h.AddHosts(ip, hosts...) + hosts, err = normalizeHostTokens(hosts...) + if err != nil { + return err + } + if len(hosts) == 0 { + return fmt.Errorf("empty host") + } + h.Lock() + defer h.Unlock() + if h.ipData == nil { + return fmt.Errorf("hosts data not initialized") + } + info := h.ipData[ip] + if len(info) == 0 { + return h.addHostsLocked("", ip, hosts...) + } else if len(info) == 1 { + node := *info[0] + node.host = append([]string(nil), hosts...) + node.comment = "" + return h.updateNodeLocked(&node) + } + h.removeIPLocked(ip) + return h.addHostsLocked("", ip, hosts...) } func (h *Host) SetHostIPs(host string, ips ...string) error { if len(ips) == 0 { return fmt.Errorf("no ip address") } - err := h.RemoveHosts(host) + var err error + host, err = normalizeHostToken(host) if err != nil { return err } + normalizedIPs := make([]string, 0, len(ips)) for _, ip := range ips { - err := h.AddHosts(ip, host) + normalizedIP, err := normalizeIPToken(ip) + if err != nil { + return err + } + normalizedIPs = append(normalizedIPs, normalizedIP) + } + h.Lock() + defer h.Unlock() + if h.hostData == nil { + return fmt.Errorf("hosts data not initialized") + } + h.removeHostLocked(host) + for _, ip := range normalizedIPs { + err := h.addHostsLocked("", ip, host) if err != nil { return err } @@ -599,14 +619,27 @@ func (h *Host) SetHostIPs(host string, ips ...string) error { func (h *Host) addHosts(comment string, ip string, hosts ...string) error { h.Lock() defer h.Unlock() + return h.addHostsLocked(comment, ip, hosts...) +} + +func (h *Host) addHostsLocked(comment string, ip string, hosts ...string) error { if h.hostData == nil { return fmt.Errorf("hosts data not initialized") } + var err error + ip, err = normalizeIPToken(ip) + if err != nil { + return err + } ipInfo := h.listHostsByIP(ip) var needAddHosts []string for _, v := range hosts { - if !inArray(ipInfo, v) { - needAddHosts = append(needAddHosts, v) + host, err := normalizeHostToken(v) + if err != nil { + return err + } + if !inArray(ipInfo, host) && !inArray(needAddHosts, host) { + needAddHosts = append(needAddHosts, host) } } if len(needAddHosts) == 0 { @@ -616,39 +649,88 @@ func (h *Host) addHosts(comment string, ip string, hosts ...string) error { comment = "#" + comment } hostNode := HostNode{ - uid: h.idx + 1, - nextuid: 0, - lastuid: h.lastUid, ip: ip, host: needAddHosts, valid: true, comment: comment, } - h.idx++ - h.fulldata[h.lastUid].nextuid = h.idx - h.lastUid = h.idx - h.fulldata[h.idx] = &hostNode - h.ipData[ip] = append(h.ipData[ip], &hostNode) - for _, v := range needAddHosts { - h.hostData[v] = append(h.hostData[v], &hostNode) + if !hostNode.CheckValid() { + return fmt.Errorf("invalid hosts node") } + h.appendNodeLocked(&hostNode) return nil } +func (h *Host) removeHostLocked(host string) { + hostInfo := h.hostData[host] + if len(hostInfo) == 0 { + return + } + delete(h.hostData, host) + for _, v := range hostInfo { + var newHosts []string + for _, vv := range v.host { + if vv != host { + newHosts = append(newHosts, vv) + } + } + v.host = newHosts + if len(v.host) == 0 { + h.removeNodeFromIPDataLocked(v.ip, v.uid) + h.unlinkNodeLocked(v) + } + } +} + +func (h *Host) removeIPLocked(ip string) { + ipInfo := h.ipData[ip] + if len(ipInfo) == 0 { + return + } + delete(h.ipData, ip) + for _, v := range ipInfo { + for _, host := range v.host { + h.removeNodeFromHostDataLocked(host, v.uid) + } + h.unlinkNodeLocked(v) + } +} + func (h *Host) InsertNodeByData(node *HostNode, before bool, comment string, ip string, hosts ...string) error { h.Lock() defer h.Unlock() if h.hostData == nil { return fmt.Errorf("hosts data not initialized") } - if _, ok := h.fulldata[node.uid]; !ok { + if node == nil { + return fmt.Errorf("node not exists") + } + anchor, ok := h.fulldata[node.uid] + if !ok { return fmt.Errorf("node not exists") } if comment != "" && !strings.HasPrefix(strings.TrimSpace(comment), "#") { comment = "#" + comment } + ip = strings.TrimSpace(ip) + var err error + hosts, err = normalizeHostTokens(hosts...) + if err != nil { + return err + } + if ip != "" || len(hosts) > 0 { + ip, err = normalizeIPToken(ip) + if err != nil { + return err + } + if len(hosts) == 0 { + return fmt.Errorf("empty host") + } + } + if ip == "" && len(hosts) == 0 && comment == "" { + return fmt.Errorf("empty node") + } hostNode := HostNode{ - uid: h.idx + 1, ip: ip, host: hosts, valid: true, @@ -657,26 +739,43 @@ func (h *Host) InsertNodeByData(node *HostNode, before bool, comment string, ip if ip == "" && len(hosts) == 0 && comment != "" { hostNode.onlyComment = true } - if before { - hostNode.nextuid = node.uid - hostNode.lastuid = node.lastuid - if node.lastuid != 0 { - h.fulldata[node.lastuid].nextuid = h.idx - } else { - h.firstUid = h.idx - } - node.lastuid = h.idx - } else { - hostNode.lastuid = node.uid - hostNode.nextuid = node.nextuid - if node.nextuid != 0 { - h.fulldata[node.nextuid].lastuid = h.idx - } else { - h.lastUid = h.idx - } - node.nextuid = h.idx + return h.insertNodeRelativeLocked(anchor, before, &hostNode) +} + +func normalizeIPToken(ip string) (string, error) { + ip = strings.TrimSpace(ip) + if net.ParseIP(ip) == nil { + return "", fmt.Errorf("invalid ip address") } - return nil + return ip, nil +} + +func normalizeHostToken(host string) (string, error) { + host = strings.TrimSpace(host) + if host == "" { + return "", fmt.Errorf("empty host") + } + if strings.ContainsAny(host, " \t\r\n") || strings.HasPrefix(host, "#") { + return "", fmt.Errorf("invalid host") + } + return host, nil +} + +func normalizeHostTokens(hosts ...string) ([]string, error) { + out := make([]string, 0, len(hosts)) + seen := make(map[string]struct{}, len(hosts)) + for _, host := range hosts { + normalized, err := normalizeHostToken(host) + if err != nil { + return nil, err + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + out = append(out, normalized) + } + return out, nil } func (h *Host) DeleteNode(node *HostNode) error { @@ -685,49 +784,38 @@ func (h *Host) DeleteNode(node *HostNode) error { if h.hostData == nil { return fmt.Errorf("hosts data not initialized") } - if _, ok := h.fulldata[node.uid]; !ok { + if node == nil { return fmt.Errorf("node not exists") } - if node.lastuid != 0 { - h.fulldata[node.lastuid].nextuid = node.nextuid - } else { - h.firstUid = node.nextuid + current, ok := h.fulldata[node.uid] + if !ok { + return fmt.Errorf("node not exists") } - if node.nextuid != 0 { - h.fulldata[node.nextuid].lastuid = node.lastuid - } else { - h.lastUid = node.lastuid + for _, v := range current.host { + h.removeNodeFromHostDataLocked(v, current.uid) } - for _, v := range node.host { - var newHostData []*HostNode - for _, vv := range h.hostData[v] { - if vv.uid != node.uid { - newHostData = append(newHostData, vv) - } - } - h.hostData[v] = newHostData + if current.ip != "" { + h.removeNodeFromIPDataLocked(current.ip, current.uid) } - ipInfo := h.ipData[node.ip] - var newIPData []*HostNode - for _, v := range ipInfo { - if v.uid != node.uid { - newIPData = append(newIPData, v) - } - } - h.ipData[node.ip] = newIPData - delete(h.fulldata, node.uid) + h.unlinkNodeLocked(current) return nil } func (h *Host) DeleteNodeByUID(uid uint64) error { - h.RLock() + h.Lock() + defer h.Unlock() node, ok := h.fulldata[uid] if !ok { - h.RUnlock() return fmt.Errorf("node not exists") } - h.RUnlock() - return h.DeleteNode(node) + for _, v := range node.host { + h.removeNodeFromHostDataLocked(v, node.uid) + } + if node.ip != "" { + h.removeNodeFromIPDataLocked(node.ip, node.uid) + } + h.unlinkNodeLocked(node) + return nil } func (h *Host) AddNode(node *HostNode) error { @@ -736,26 +824,12 @@ func (h *Host) AddNode(node *HostNode) error { if h.hostData == nil { return fmt.Errorf("hosts data not initialized") } - if node.comment != "" && !strings.HasPrefix(strings.TrimSpace(node.comment), "#") { - node.comment = "#" + node.comment - } - if node.ip == "" && len(node.host) == 0 && node.comment != "" { - node.onlyComment = true - } - h.idx++ - node.uid = h.idx - node.lastuid = h.lastUid - node.nextuid = 0 - h.fulldata[h.lastUid].nextuid = node.uid - h.lastUid = node.uid - h.fulldata[node.uid] = node - node.valid = node.CheckValid() - if node.valid { - h.ipData[node.ip] = append(h.ipData[node.ip], node) - for _, v := range node.host { - h.hostData[v] = append(h.hostData[v], node) - } + prepared, err := normalizeEditableNode(node) + if err != nil { + return err } + h.appendNodeLocked(prepared) + copyHostNodeState(node, prepared) return nil } @@ -768,28 +842,157 @@ func (h *Host) InsertNode(node *HostNode) error { if h.hostData == nil { return fmt.Errorf("hosts data not initialized") } - if node.comment != "" && !strings.HasPrefix(strings.TrimSpace(node.comment), "#") { - node.comment = "#" + node.comment + prepared, err := normalizeEditableNode(node) + if err != nil { + return err } - if node.ip == "" && len(node.host) == 0 && node.comment != "" { - node.onlyComment = true + prepared.uid = h.idx + 1 + lastNode := h.fulldata[prepared.lastuid] + nextNode := h.fulldata[prepared.nextuid] + if lastNode == nil || nextNode == nil { + return fmt.Errorf("node link not exists") } - node.uid = h.idx + 1 - if h.fulldata[node.lastuid].nextuid != node.nextuid { + if lastNode.nextuid != prepared.nextuid || nextNode.lastuid != prepared.lastuid { return fmt.Errorf("node lastuid nextuid not match") } h.idx++ - h.fulldata[node.lastuid].nextuid = node.uid - h.fulldata[node.nextuid].lastuid = node.uid - h.fulldata[node.uid] = node + lastNode.nextuid = prepared.uid + nextNode.lastuid = prepared.uid + h.storeNodeLocked(prepared) + copyHostNodeState(node, prepared) + return nil +} + +func (h *Host) appendNodeLocked(node *HostNode) { + if node == nil { + return + } + h.idx++ + node.uid = h.idx + node.lastuid = h.lastUid + node.nextuid = 0 + if h.lastUid != 0 { + if last := h.fulldata[h.lastUid]; last != nil { + last.nextuid = node.uid + } + } else { + h.firstUid = node.uid + } + if h.firstUid == 0 { + h.firstUid = node.uid + } + h.lastUid = node.uid + h.storeNodeLocked(node) +} + +func (h *Host) insertNodeRelativeLocked(anchor *HostNode, before bool, node *HostNode) error { + if anchor == nil || node == nil { + return fmt.Errorf("node not exists") + } + h.idx++ + node.uid = h.idx + if before { + node.nextuid = anchor.uid + node.lastuid = anchor.lastuid + if anchor.lastuid != 0 { + prev := h.fulldata[anchor.lastuid] + if prev == nil { + return fmt.Errorf("node lastuid not exists") + } + prev.nextuid = node.uid + } else { + h.firstUid = node.uid + } + anchor.lastuid = node.uid + } else { + node.lastuid = anchor.uid + node.nextuid = anchor.nextuid + if anchor.nextuid != 0 { + next := h.fulldata[anchor.nextuid] + if next == nil { + return fmt.Errorf("node nextuid not exists") + } + next.lastuid = node.uid + } else { + h.lastUid = node.uid + } + anchor.nextuid = node.uid + } + if h.firstUid == 0 { + h.firstUid = node.uid + } + if node.nextuid == 0 { + h.lastUid = node.uid + } node.valid = node.CheckValid() - if node.valid { - h.ipData[node.ip] = append(h.ipData[node.ip], node) - for _, v := range node.host { - h.hostData[v] = append(h.hostData[v], node) + h.storeNodeLocked(node) + return nil +} + +func (h *Host) storeNodeLocked(node *HostNode) { + if node == nil { + return + } + stored := cloneHostNode(node) + h.fulldata[stored.uid] = stored + if !stored.valid { + return + } + if stored.ip != "" { + h.ipData[stored.ip] = append(h.ipData[stored.ip], stored) + } + for _, v := range stored.host { + h.hostData[v] = append(h.hostData[v], stored) + } +} + +func (h *Host) removeNodeFromIPDataLocked(ip string, uid uint64) { + var newIPData []*HostNode + for _, node := range h.ipData[ip] { + if node.uid != uid { + newIPData = append(newIPData, node) } } - return nil + if len(newIPData) == 0 { + delete(h.ipData, ip) + return + } + h.ipData[ip] = newIPData +} + +func (h *Host) removeNodeFromHostDataLocked(host string, uid uint64) { + var newHostData []*HostNode + for _, node := range h.hostData[host] { + if node.uid != uid { + newHostData = append(newHostData, node) + } + } + if len(newHostData) == 0 { + delete(h.hostData, host) + return + } + h.hostData[host] = newHostData +} + +func (h *Host) unlinkNodeLocked(node *HostNode) { + if node == nil { + return + } + if node.lastuid != 0 { + if last := h.fulldata[node.lastuid]; last != nil { + last.nextuid = node.nextuid + } + } else { + h.firstUid = node.nextuid + } + if node.nextuid != 0 { + if next := h.fulldata[node.nextuid]; next != nil { + next.lastuid = node.lastuid + } + } else { + h.lastUid = node.lastuid + } + delete(h.fulldata, node.uid) } func inArray(arr []string, v string) bool { @@ -812,7 +1015,7 @@ func (h *Host) Save() error { func (h *Host) save(path string) error { h.Lock() defer h.Unlock() - if h.firstUid == 0 || h.fulldata == nil { + if h.fulldata == nil { return fmt.Errorf("no data") } f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) @@ -820,13 +1023,30 @@ func (h *Host) save(path string) error { return fmt.Errorf("open hosts file %s error: %s", h.hostPath, err) } defer f.Close() + if h.firstUid == 0 { + return nil + } node := h.fulldata[h.firstUid] if node == nil { return fmt.Errorf("no data") } for { if node.onlyComment { - if _, err := f.WriteString(node.comment + "\n"); err != nil { + if node.original != "" && strings.TrimSpace(node.original) == node.comment { + if _, err := f.WriteString(node.original + lineBreaker); err != nil { + return fmt.Errorf("write hosts file error: %s", err) + } + } else if _, err := f.WriteString(node.comment + lineBreaker); err != nil { + return fmt.Errorf("write hosts file error: %s", err) + } + if node.nextuid == 0 { + break + } + node = h.fulldata[node.nextuid] + continue + } + if !node.valid { + if _, err := f.WriteString(node.original + lineBreaker); err != nil { return fmt.Errorf("write hosts file error: %s", err) } if node.nextuid == 0 { @@ -858,11 +1078,14 @@ func (h *Host) save(path string) error { } func (h *Host) Build() ([]byte, error) { - if h.firstUid == 0 || h.fulldata == nil { - return nil, fmt.Errorf("no data") - } h.Lock() defer h.Unlock() + if h.fulldata == nil { + return nil, fmt.Errorf("no data") + } + if h.firstUid == 0 { + return []byte{}, nil + } var f bytes.Buffer node := h.fulldata[h.firstUid] if node == nil { @@ -870,7 +1093,21 @@ func (h *Host) Build() ([]byte, error) { } for { if node.onlyComment { - if _, err := f.WriteString(node.comment + "\n"); err != nil { + if node.original != "" && strings.TrimSpace(node.original) == node.comment { + if _, err := f.WriteString(node.original + lineBreaker); err != nil { + return nil, fmt.Errorf("write hosts file error: %s", err) + } + } else if _, err := f.WriteString(node.comment + lineBreaker); err != nil { + return nil, fmt.Errorf("write hosts file error: %s", err) + } + if node.nextuid == 0 { + break + } + node = h.fulldata[node.nextuid] + continue + } + if !node.valid { + if _, err := f.WriteString(node.original + lineBreaker); err != nil { return nil, fmt.Errorf("write hosts file error: %s", err) } if node.nextuid == 0 { @@ -910,7 +1147,7 @@ func (h *Host) GetFirstNode() (*HostNode, error) { if h.fulldata == nil { return nil, fmt.Errorf("no data") } - return h.fulldata[h.firstUid], nil + return cloneHostNode(h.fulldata[h.firstUid]), nil } func (h *Host) GetNode(nodeID uint64) (*HostNode, error) { @@ -920,7 +1157,7 @@ func (h *Host) GetNode(nodeID uint64) (*HostNode, error) { return nil, fmt.Errorf("no data") } if node, ok := h.fulldata[nodeID]; ok { - return node, nil + return cloneHostNode(node), nil } return nil, fmt.Errorf("node not exists") } @@ -935,7 +1172,7 @@ func (h *Host) GetNextNode(nodeID uint64) (*HostNode, error) { if node.nextuid == 0 { return nil, fmt.Errorf("no next node") } - return h.fulldata[node.nextuid], nil + return cloneHostNode(h.fulldata[node.nextuid]), nil } return nil, fmt.Errorf("node not exists") } @@ -950,7 +1187,7 @@ func (h *Host) GetLastNode(nodeID uint64) (*HostNode, error) { if node.lastuid == 0 { return nil, fmt.Errorf("no last node") } - return h.fulldata[node.lastuid], nil + return cloneHostNode(h.fulldata[node.lastuid]), nil } return nil, fmt.Errorf("node not exists") } @@ -964,50 +1201,44 @@ func (h *Host) GetLatestNode() (*HostNode, error) { if h.fulldata == nil { return nil, fmt.Errorf("no data") } - return h.fulldata[h.lastUid], nil + return cloneHostNode(h.fulldata[h.lastUid]), nil } func (h *Host) UpdateNode(node *HostNode) error { h.Lock() defer h.Unlock() + return h.updateNodeLocked(node) +} + +func (h *Host) updateNodeLocked(node *HostNode) error { + if node == nil { + return fmt.Errorf("node not exists") + } if h.fulldata == nil { return fmt.Errorf("no data") } - if _, ok := h.fulldata[node.uid]; !ok { + current, ok := h.fulldata[node.uid] + if !ok { return fmt.Errorf("node not exists") } - if node.comment != "" && !strings.HasPrefix(strings.TrimSpace(node.comment), "#") { - node.comment = "#" + node.comment + candidate, err := normalizeEditableNode(node) + if err != nil { + return err } - if node.ip == "" && len(node.host) == 0 && node.comment != "" { - node.onlyComment = true + candidate.uid = current.uid + candidate.lastuid = current.lastuid + candidate.nextuid = current.nextuid + if candidate.original == "" { + candidate.original = current.original } - h.fulldata[node.uid] = node - node.valid = node.CheckValid() - for k, v := range h.ipData { - var newIPData []*HostNode - for _, vv := range v { - if vv.uid != node.uid { - newIPData = append(newIPData, vv) - } - } - h.ipData[k] = newIPData + for _, host := range current.host { + h.removeNodeFromHostDataLocked(host, current.uid) } - for k, v := range h.hostData { - var newHostData []*HostNode - for _, vv := range v { - if vv.uid != node.uid { - newHostData = append(newHostData, vv) - } - } - h.hostData[k] = newHostData - } - if node.valid { - h.ipData[node.ip] = append(h.ipData[node.ip], node) - for _, v := range node.host { - h.hostData[v] = append(h.hostData[v], node) - } + if current.ip != "" { + h.removeNodeFromIPDataLocked(current.ip, current.uid) } + h.storeNodeLocked(candidate) + copyHostNodeState(node, candidate) return nil } diff --git a/hosts/hosts_test.go b/hosts/hosts_test.go index d356886..b0bca3e 100644 --- a/hosts/hosts_test.go +++ b/hosts/hosts_test.go @@ -2,11 +2,15 @@ package hosts import ( "fmt" + "os" + "path/filepath" + "strings" "testing" ) func Test_Hosts(t *testing.T) { var h = NewHosts() + tmpDir := t.TempDir() err := h.Parse("./test_hosts.txt") if err != nil { t.Error(err) @@ -89,7 +93,7 @@ func Test_Hosts(t *testing.T) { t.Log(data) } - err = h.SaveAs("./test_hosts_01.txt") + err = h.SaveAs(filepath.Join(tmpDir, "test_hosts_01.txt")) if err != nil { t.Error(err) } @@ -133,7 +137,7 @@ func Test_Hosts(t *testing.T) { t.Error("Expected 1 got ", data) } - err = h.SaveAs("./test_hosts_02.txt") + err = h.SaveAs(filepath.Join(tmpDir, "test_hosts_02.txt")) if err != nil { t.Error(err) } @@ -152,3 +156,577 @@ func BenchmarkAddHosts(b *testing.B) { } } } + +func TestParseHandlesEmptyAndNoTrailingNewline(t *testing.T) { + t.Run("empty file", func(t *testing.T) { + h := NewHosts() + path := filepath.Join(t.TempDir(), "hosts.empty") + if err := os.WriteFile(path, nil, 0o644); err != nil { + t.Fatal(err) + } + if err := h.Parse(path); err != nil { + t.Fatal(err) + } + if got := h.List(); len(got) != 0 { + t.Fatalf("expected empty hosts list, got %d entries", len(got)) + } + }) + + t.Run("last line without newline", func(t *testing.T) { + h := NewHosts() + path := filepath.Join(t.TempDir(), "hosts.nonewline") + if err := os.WriteFile(path, []byte("1.2.3.4 example.test"), 0o644); err != nil { + t.Fatal(err) + } + if err := h.Parse(path); err != nil { + t.Fatal(err) + } + if got := h.ListIPsByHost("example.test"); len(got) != 1 || got[0] != "1.2.3.4" { + t.Fatalf("expected last line to be parsed, got %v", got) + } + node, err := h.GetLatestNode() + if err != nil { + t.Fatal(err) + } + if node.NextUID() != 0 { + t.Fatalf("expected last node next uid 0, got %d", node.NextUID()) + } + }) +} + +func TestAddHostsAndAddNodeWorkOnEmptyModel(t *testing.T) { + t.Run("add hosts", func(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "example.test"); err != nil { + t.Fatal(err) + } + if got := h.ListFirstIPByHost("example.test"); got != "1.2.3.4" { + t.Fatalf("expected inserted host ip, got %q", got) + } + out, err := h.Build() + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(out), "1.2.3.4 example.test") { + t.Fatalf("unexpected build output: %q", out) + } + }) + + t.Run("add node", func(t *testing.T) { + h := NewHosts() + node := &HostNode{} + node.SetIP("5.6.7.8") + node.SetHosts("node.test") + if err := h.AddNode(node); err != nil { + t.Fatal(err) + } + if node.UID() == 0 { + t.Fatal("expected node uid to be assigned") + } + if got := h.ListFirstIPByHost("node.test"); got != "5.6.7.8" { + t.Fatalf("expected inserted node ip, got %q", got) + } + }) +} + +func TestInsertNodeByDataInsertsAndLinksNode(t *testing.T) { + h := NewHosts() + path := filepath.Join(t.TempDir(), "hosts.insert") + if err := os.WriteFile(path, []byte("2.2.2.2 anchor.test\n"), 0o644); err != nil { + t.Fatal(err) + } + if err := h.Parse(path); err != nil { + t.Fatal(err) + } + anchor, err := h.GetFirstNode() + if err != nil { + t.Fatal(err) + } + if err := h.InsertNodeByData(anchor, true, "before", "1.1.1.1", "before.test"); err != nil { + t.Fatal(err) + } + if err := h.InsertNodeByData(anchor, false, "after", "3.3.3.3", "after.test"); err != nil { + t.Fatal(err) + } + + nodes := h.List() + if len(nodes) != 3 { + t.Fatalf("expected 3 nodes after insert, got %d", len(nodes)) + } + if nodes[0].IP() != "1.1.1.1" || nodes[1].IP() != "2.2.2.2" || nodes[2].IP() != "3.3.3.3" { + t.Fatalf("unexpected node order: %q, %q, %q", nodes[0].IP(), nodes[1].IP(), nodes[2].IP()) + } + if got := h.ListFirstIPByHost("before.test"); got != "1.1.1.1" { + t.Fatalf("expected before node to be indexed, got %q", got) + } + if got := h.ListFirstIPByHost("after.test"); got != "3.3.3.3" { + t.Fatalf("expected after node to be indexed, got %q", got) + } + if nodes[0].NextUID() != nodes[1].UID() || nodes[1].LastUID() != nodes[0].UID() { + t.Fatalf("before/anchor linkage broken: before.next=%d anchor.uid=%d anchor.last=%d", nodes[0].NextUID(), nodes[1].UID(), nodes[1].LastUID()) + } + if nodes[1].NextUID() != nodes[2].UID() || nodes[2].LastUID() != nodes[1].UID() { + t.Fatalf("anchor/after linkage broken: anchor.next=%d after.uid=%d after.last=%d", nodes[1].NextUID(), nodes[2].UID(), nodes[2].LastUID()) + } +} + +func TestInsertNodeByDataRejectsNilAnchor(t *testing.T) { + h := NewHosts() + path := filepath.Join(t.TempDir(), "hosts.insert.nil") + if err := os.WriteFile(path, []byte("2.2.2.2 anchor.test\n"), 0o644); err != nil { + t.Fatal(err) + } + if err := h.Parse(path); err != nil { + t.Fatal(err) + } + if err := h.InsertNodeByData(nil, true, "before", "1.1.1.1", "before.test"); err == nil { + t.Fatal("expected nil anchor error") + } +} + +func TestSetIPHostsUpdatesReverseIndex(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "old.test"); err != nil { + t.Fatal(err) + } + if err := h.SetIPHosts("1.2.3.4", "new.test"); err != nil { + t.Fatal(err) + } + if got := h.ListIPsByHost("new.test"); len(got) != 1 || got[0] != "1.2.3.4" { + t.Fatalf("expected new reverse index, got %v", got) + } + if got := h.ListIPsByHost("old.test"); len(got) != 0 { + t.Fatalf("expected old reverse index to be removed, got %v", got) + } +} + +func TestSetIPHostsDeduplicatesHosts(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "old.test"); err != nil { + t.Fatal(err) + } + if err := h.SetIPHosts("1.2.3.4", "new.test", "new.test"); err != nil { + t.Fatal(err) + } + if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "new.test" { + t.Fatalf("expected deduplicated ip mapping, got %v", got) + } + if got := h.ListIPsByHost("new.test"); len(got) != 1 || got[0] != "1.2.3.4" { + t.Fatalf("expected deduplicated reverse index, got %v", got) + } +} + +func TestSetIPHostsReplacesMultipleSameIPNodes(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "first.test"); err != nil { + t.Fatal(err) + } + if err := h.AddHosts("1.2.3.4", "second.test"); err != nil { + t.Fatal(err) + } + if err := h.AddHosts("5.6.7.8", "tail.test"); err != nil { + t.Fatal(err) + } + if err := h.SetIPHosts("1.2.3.4", "new.test"); err != nil { + t.Fatal(err) + } + if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "new.test" { + t.Fatalf("expected replaced same-ip mappings, got %v", got) + } + if got := h.ListIPsByHost("first.test"); len(got) != 0 { + t.Fatalf("expected first old host to disappear, got %v", got) + } + if got := h.ListIPsByHost("second.test"); len(got) != 0 { + t.Fatalf("expected second old host to disappear, got %v", got) + } + if got := h.ListFirstIPByHost("tail.test"); got != "5.6.7.8" { + t.Fatalf("expected tail node to remain linked, got %q", got) + } + nodes := h.List() + if len(nodes) != 2 || nodes[0].IP() != "5.6.7.8" || nodes[1].IP() != "1.2.3.4" { + t.Fatalf("unexpected node list after SetIPHosts: %#v", nodes) + } + if nodes[0].NextUID() != nodes[1].UID() || nodes[1].LastUID() != nodes[0].UID() { + t.Fatalf("remaining node linkage broken: first.next=%d second.uid=%d second.last=%d", nodes[0].NextUID(), nodes[1].UID(), nodes[1].LastUID()) + } +} + +func TestSetHostIPsReplacesMappingsInOneOperation(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "old.test"); err != nil { + t.Fatal(err) + } + if err := h.SetHostIPs("old.test", "2.2.2.2", "3.3.3.3"); err != nil { + t.Fatal(err) + } + if got := h.ListIPsByHost("old.test"); len(got) != 2 || got[0] != "2.2.2.2" || got[1] != "3.3.3.3" { + t.Fatalf("expected replaced host ip mappings, got %v", got) + } + if got := h.ListHostsByIP("1.2.3.4"); len(got) != 0 { + t.Fatalf("expected old ip mapping to be removed, got %v", got) + } +} + +func TestSetIPHostsRejectsInvalidInputWithoutMutating(t *testing.T) { + tests := []struct { + name string + ip string + hosts []string + }{ + {name: "bad ip", ip: "bad-ip", hosts: []string{"new.test"}}, + {name: "empty host", ip: "1.2.3.4", hosts: []string{""}}, + {name: "comment host", ip: "1.2.3.4", hosts: []string{"#bad.test"}}, + {name: "missing host", ip: "1.2.3.4"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "old.test"); err != nil { + t.Fatal(err) + } + if err := h.SetIPHosts(tt.ip, tt.hosts...); err == nil { + t.Fatal("expected invalid SetIPHosts input to fail") + } + if got := h.ListIPsByHost("old.test"); len(got) != 1 || got[0] != "1.2.3.4" { + t.Fatalf("old host mapping should remain after failed SetIPHosts, got %v", got) + } + if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "old.test" { + t.Fatalf("ip index should remain after failed SetIPHosts, got %v", got) + } + if got := h.ListIPsByHost("new.test"); len(got) != 0 { + t.Fatalf("failed SetIPHosts should not add new host, got %v", got) + } + }) + } +} + +func TestSetHostIPsRejectsInvalidInputWithoutMutating(t *testing.T) { + tests := []struct { + name string + host string + ips []string + }{ + {name: "empty host", host: "", ips: []string{"2.2.2.2"}}, + {name: "comment host", host: "#old.test", ips: []string{"2.2.2.2"}}, + {name: "bad ip", host: "old.test", ips: []string{"2.2.2.2", "bad-ip"}}, + {name: "missing ip", host: "old.test"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "old.test"); err != nil { + t.Fatal(err) + } + if err := h.SetHostIPs(tt.host, tt.ips...); err == nil { + t.Fatal("expected invalid SetHostIPs input to fail") + } + if got := h.ListIPsByHost("old.test"); len(got) != 1 || got[0] != "1.2.3.4" { + t.Fatalf("old host mapping should remain after failed SetHostIPs, got %v", got) + } + if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "old.test" { + t.Fatalf("ip index should remain after failed SetHostIPs, got %v", got) + } + if got := h.ListHostsByIP("2.2.2.2"); len(got) != 0 { + t.Fatalf("failed SetHostIPs should not add partial ip mapping, got %v", got) + } + }) + } +} + +func TestRemoveIPHostsKeepsSameIPOtherNodes(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "first.test"); err != nil { + t.Fatal(err) + } + if err := h.AddHosts("1.2.3.4", "second.test"); err != nil { + t.Fatal(err) + } + if err := h.RemoveIPHosts("1.2.3.4", "first.test"); err != nil { + t.Fatal(err) + } + if got := h.ListIPsByHost("first.test"); len(got) != 0 { + t.Fatalf("expected removed host to disappear, got %v", got) + } + if got := h.ListIPsByHost("second.test"); len(got) != 1 || got[0] != "1.2.3.4" { + t.Fatalf("expected same-ip sibling node to stay indexed, got %v", got) + } + if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "second.test" { + t.Fatalf("expected ip index to keep sibling host, got %v", got) + } +} + +func TestRemoveHostsKeepsSameIPOtherNodes(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "first.test"); err != nil { + t.Fatal(err) + } + if err := h.AddHosts("1.2.3.4", "second.test"); err != nil { + t.Fatal(err) + } + if err := h.RemoveHosts("first.test"); err != nil { + t.Fatal(err) + } + if got := h.ListHostsByIP("1.2.3.4"); len(got) != 1 || got[0] != "second.test" { + t.Fatalf("expected ip index to keep sibling host after RemoveHosts, got %v", got) + } +} + +func TestRemoveIPsUnlinksAdjacentNodes(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "first.test"); err != nil { + t.Fatal(err) + } + if err := h.AddHosts("1.2.3.4", "second.test"); err != nil { + t.Fatal(err) + } + if err := h.AddHosts("5.6.7.8", "tail.test"); err != nil { + t.Fatal(err) + } + if err := h.RemoveIPs("1.2.3.4"); err != nil { + t.Fatal(err) + } + nodes := h.List() + if len(nodes) != 1 || nodes[0].IP() != "5.6.7.8" || nodes[0].LastUID() != 0 || nodes[0].NextUID() != 0 { + t.Fatalf("expected only tail node with clean links, got %#v", nodes) + } + if got := h.ListHostsByIP("1.2.3.4"); len(got) != 0 { + t.Fatalf("expected removed ip index to be empty, got %v", got) + } + if got := h.ListFirstIPByHost("tail.test"); got != "5.6.7.8" { + t.Fatalf("expected tail reverse index to remain, got %q", got) + } +} + +func TestAddHostsRejectsInvalidInput(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("not-an-ip", "bad.test"); err == nil { + t.Fatal("expected invalid ip error") + } + if got := h.ListHostsByIP("not-an-ip"); len(got) != 0 { + t.Fatalf("invalid ip should not be indexed, got %v", got) + } + if err := h.AddHosts("1.2.3.4", ""); err == nil { + t.Fatal("expected empty host error") + } + if got := h.ListHostsByIP("1.2.3.4"); len(got) != 0 { + t.Fatalf("empty host should not be indexed, got %v", got) + } +} + +func TestInsertNodeByDataRejectsInvalidHostDataWithoutMutating(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("2.2.2.2", "anchor.test"); err != nil { + t.Fatal(err) + } + anchor, err := h.GetFirstNode() + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + ip string + hosts []string + }{ + {name: "bad ip", ip: "not-an-ip", hosts: []string{"bad.test"}}, + {name: "empty host", ip: "1.1.1.1", hosts: []string{""}}, + {name: "comment host", ip: "1.1.1.1", hosts: []string{"#bad.test"}}, + {name: "missing host", ip: "1.1.1.1"}, + } + for _, tt := range tests { + if err := h.InsertNodeByData(anchor, false, "", tt.ip, tt.hosts...); err == nil { + t.Fatalf("%s: expected error", tt.name) + } + } + nodes := h.List() + if len(nodes) != 1 || nodes[0].IP() != "2.2.2.2" { + t.Fatalf("invalid insert should not mutate node list: %#v", nodes) + } + if got := h.ListHostsByIP("1.1.1.1"); len(got) != 0 { + t.Fatalf("invalid insert should not mutate ip index: %v", got) + } + + if err := h.InsertNodeByData(anchor, true, "comment-only", ""); err != nil { + t.Fatalf("comment-only insert should remain valid: %v", err) + } + nodes = h.List() + if len(nodes) != 2 || !nodes[0].OnlyComment() || nodes[1].IP() != "2.2.2.2" { + t.Fatalf("comment-only insert mismatch: %#v", nodes) + } +} + +func TestInsertNodeByDataRejectsEmptyNodeWithoutMutating(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("2.2.2.2", "anchor.test"); err != nil { + t.Fatal(err) + } + anchor, err := h.GetFirstNode() + if err != nil { + t.Fatal(err) + } + + if err := h.InsertNodeByData(anchor, true, "", ""); err == nil { + t.Fatal("expected empty insert to fail") + } + nodes := h.List() + if len(nodes) != 1 || nodes[0].IP() != "2.2.2.2" { + t.Fatalf("empty insert should not mutate node list: %#v", nodes) + } + out, err := h.Build() + if err != nil { + t.Fatal(err) + } + if got, want := string(out), "2.2.2.2 anchor.test"+lineBreaker; got != want { + t.Fatalf("empty insert should not change output: got %q want %q", got, want) + } +} + +func TestEmptyHostsBuildAndSaveAs(t *testing.T) { + h := NewHosts() + path := filepath.Join(t.TempDir(), "hosts.empty") + if err := os.WriteFile(path, nil, 0o644); err != nil { + t.Fatal(err) + } + if err := h.Parse(path); err != nil { + t.Fatal(err) + } + out, err := h.Build() + if err != nil { + t.Fatal(err) + } + if len(out) != 0 { + t.Fatalf("expected empty build output, got %q", out) + } + outPath := filepath.Join(t.TempDir(), "hosts.out") + if err := h.SaveAs(outPath); err != nil { + t.Fatal(err) + } + saved, err := os.ReadFile(outPath) + if err != nil { + t.Fatal(err) + } + if len(saved) != 0 { + t.Fatalf("expected empty saved file, got %q", saved) + } +} + +func TestParsePreservesBlankAndRawLines(t *testing.T) { + h := NewHosts() + path := filepath.Join(t.TempDir(), "hosts.raw") + input := []byte("127.0.0.1 localhost\n\nbadline\n# tail comment\n") + if err := os.WriteFile(path, input, 0o644); err != nil { + t.Fatal(err) + } + if err := h.Parse(path); err != nil { + t.Fatal(err) + } + out, err := h.Build() + if err != nil { + t.Fatal(err) + } + got := string(out) + if !strings.Contains(got, "127.0.0.1 localhost"+lineBreaker+lineBreaker+"badline"+lineBreaker) { + t.Fatalf("expected blank/raw lines to be preserved, got %q", got) + } + if !strings.Contains(got, "# tail comment"+lineBreaker) { + t.Fatalf("expected comment line to be preserved, got %q", got) + } +} + +func TestHostAccessorsReturnDetachedCopies(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "example.test"); err != nil { + t.Fatal(err) + } + node, err := h.GetFirstNode() + if err != nil { + t.Fatal(err) + } + node.SetIP("9.9.9.9") + node.SetHosts("mutated.test") + + if got := h.ListFirstIPByHost("example.test"); got != "1.2.3.4" { + t.Fatalf("detached copy mutated internal host index: %q", got) + } + if got := h.ListFirstIPByHost("mutated.test"); got != "" { + t.Fatalf("detached copy should not create new host index: %q", got) + } + out, err := h.Build() + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(out), "9.9.9.9 mutated.test") { + t.Fatalf("detached copy leaked into build output: %q", out) + } +} + +func TestUpdateNodeRejectsInvalidMutationAndPreservesState(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "example.test"); err != nil { + t.Fatal(err) + } + node, err := h.GetFirstNode() + if err != nil { + t.Fatal(err) + } + node.SetIP("bad-ip") + if err := h.UpdateNode(node); err == nil { + t.Fatal("expected invalid update to fail") + } + if got := h.ListFirstIPByHost("example.test"); got != "1.2.3.4" { + t.Fatalf("failed update should preserve previous index, got %q", got) + } + out, err := h.Build() + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(out), "1.2.3.4 example.test") || strings.Contains(string(out), "bad-ip") { + t.Fatalf("failed update should preserve previous output, got %q", out) + } +} + +func TestUpdateNodeCommentOnlyStateReindexesCleanly(t *testing.T) { + h := NewHosts() + if err := h.AddHosts("1.2.3.4", "example.test"); err != nil { + t.Fatal(err) + } + node, err := h.GetFirstNode() + if err != nil { + t.Fatal(err) + } + node.SetIP("") + node.SetHosts() + node.SetComment("note") + if err := h.UpdateNode(node); err != nil { + t.Fatalf("comment-only update failed: %v", err) + } + if got := h.ListByIP(""); len(got) != 0 { + t.Fatalf("comment-only node should not be indexed under empty ip: %#v", got) + } + updated, err := h.GetNode(node.UID()) + if err != nil { + t.Fatal(err) + } + if !updated.OnlyComment() { + t.Fatalf("comment-only node should keep onlyComment state: %#v", updated) + } + node = updated + node.SetIP("2.2.2.2") + node.SetHosts("restored.test") + if err := h.UpdateNode(node); err != nil { + t.Fatalf("restoring host entry failed: %v", err) + } + updated, err = h.GetNode(node.UID()) + if err != nil { + t.Fatal(err) + } + if updated.OnlyComment() { + t.Fatalf("host entry should clear onlyComment after restore: %#v", updated) + } + if got := h.ListFirstIPByHost("restored.test"); got != "2.2.2.2" { + t.Fatalf("restored host index mismatch: %q", got) + } + if got := h.ListByIP(""); len(got) != 0 { + t.Fatalf("restored host should still avoid empty ip index: %#v", got) + } +} diff --git a/math.go b/math.go index 3643f48..9143860 100644 --- a/math.go +++ b/math.go @@ -1,305 +1,403 @@ package staros import ( - "errors" "fmt" "math" "strconv" "strings" + "unicode" ) -func Calc(math string) (float64, error) { - math = strings.Replace(math, " ", "", -1) - math = strings.ToLower(math) - if err := check(math); err != nil { +// Calc evaluates a small frozen arithmetic expression language kept for +// compatibility with older staros callers. +func Calc(expr string) (float64, error) { + parser := calcParser{input: strings.ToLower(strings.TrimSpace(expr))} + if parser.input == "" { + return 0, fmt.Errorf("empty expression") + } + value, err := parser.parseExpression() + if err != nil { return 0, err } - result,err:=calc(math) - if err!=nil { - return 0,err + parser.skipSpace() + if !parser.done() { + return 0, fmt.Errorf("unexpected token %q at position %d", parser.peek(), parser.pos) } - return floatRound(result,15),nil + return normalizeCalcFloat(value), nil } -func floatRound(f float64, n int) float64 { - format := "%." + strconv.Itoa(n) + "f" - res, _ := strconv.ParseFloat(fmt.Sprintf(format, f), 64) - return res +type calcParser struct { + input string + pos int } -func check(math string) error { - math = strings.Replace(math, " ", "", -1) - math = strings.ToLower(math) - var bracketSum int - var signReady bool - for k, v := range math { - if string([]rune{v}) == "(" { - bracketSum++ - } - if string([]rune{v}) == ")" { - bracketSum-- - } - if bracketSum < 0 { - return fmt.Errorf("err at position %d.Reason is right bracket position not correct,except (", k) - } - if containSign(string([]rune{v})) { - if signReady { - if string([]rune{v}) != "+" && string([]rune{v}) != "-" { - return fmt.Errorf("err at position %d.Reason is sign %s not correct", k, string([]rune{v})) - } - } else { - signReady = true - continue +func (p *calcParser) parseExpression() (float64, error) { + return p.parseAddSub() +} + +func (p *calcParser) parseAddSub() (float64, error) { + left, err := p.parseMulDiv() + if err != nil { + return 0, err + } + for { + p.skipSpace() + switch p.peek() { + case '+': + p.pos++ + right, err := p.parseMulDiv() + if err != nil { + return 0, err } + left += right + case '-': + p.pos++ + right, err := p.parseMulDiv() + if err != nil { + return 0, err + } + left -= right + default: + return left, nil } - signReady = false } - if bracketSum != 0 { - return fmt.Errorf("Error:right bracket is not equal as left bracket") - } - return nil } -func calc(math string) (float64, error) { - var bracketLeft int - var bracketRight int - var DupStart int = -1 - for pos, str := range math { - if string(str) == "(" { - bracketLeft = pos +func (p *calcParser) parseMulDiv() (float64, error) { + left, err := p.parsePower() + if err != nil { + return 0, err + } + for { + p.skipSpace() + switch p.peek() { + case '*': + p.pos++ + right, err := p.parsePower() + if err != nil { + return 0, err + } + left *= right + case '/': + p.pos++ + right, err := p.parsePower() + if err != nil { + return 0, err + } + if right == 0 { + return 0, fmt.Errorf("divisor cannot be 0") + } + left /= right + default: + return left, nil } - if string(str) == ")" { - bracketRight = pos + } +} + +func (p *calcParser) parsePower() (float64, error) { + left, err := p.parseUnary() + if err != nil { + return 0, err + } + p.skipSpace() + if p.peek() != '^' { + return left, nil + } + p.pos++ + right, err := p.parsePower() + if err != nil { + return 0, err + } + return math.Pow(left, right), nil +} + +func (p *calcParser) parseUnary() (float64, error) { + p.skipSpace() + switch p.peek() { + case '+': + p.pos++ + return p.parseUnary() + case '-': + p.pos++ + value, err := p.parseUnary() + if err != nil { + return 0, err + } + return -value, nil + default: + return p.parsePrimary() + } +} + +func (p *calcParser) parsePrimary() (float64, error) { + p.skipSpace() + if p.done() { + return 0, fmt.Errorf("unexpected end of expression") + } + ch := p.peek() + switch { + case ch == '(': + p.pos++ + value, err := p.parseExpression() + if err != nil { + return 0, err + } + p.skipSpace() + if p.peek() != ')' { + return 0, fmt.Errorf("missing ')' at position %d", p.pos) + } + p.pos++ + return value, nil + case isCalcNumberStart(p.input, p.pos): + return p.parseNumber() + case isCalcIdentStart(ch): + return p.parseIdentifier() + default: + return 0, fmt.Errorf("unexpected token %q at position %d", ch, p.pos) + } +} + +func (p *calcParser) parseNumber() (float64, error) { + start := p.pos + seenDot := false + seenExp := false + for !p.done() { + ch := p.peek() + switch { + case ch >= '0' && ch <= '9': + p.pos++ + case ch == '.' && !seenDot && !seenExp: + seenDot = true + p.pos++ + case (ch == 'e') && !seenExp: + seenExp = true + p.pos++ + if !p.done() && (p.peek() == '+' || p.peek() == '-') { + p.pos++ + } + default: + value, err := strconv.ParseFloat(p.input[start:p.pos], 64) + if err != nil { + return 0, fmt.Errorf("invalid number %q at position %d", p.input[start:p.pos], start) + } + return value, nil + } + } + value, err := strconv.ParseFloat(p.input[start:p.pos], 64) + if err != nil { + return 0, fmt.Errorf("invalid number %q at position %d", p.input[start:p.pos], start) + } + return value, nil +} + +func (p *calcParser) parseIdentifier() (float64, error) { + start := p.pos + for !p.done() && isCalcIdent(p.peek()) { + p.pos++ + } + name := p.input[start:p.pos] + p.skipSpace() + if p.peek() != '(' { + value, ok := calcConstant(name) + if !ok { + return 0, fmt.Errorf("unknown identifier %q at position %d", name, start) + } + return value, nil + } + p.pos++ + args, err := p.parseArguments(name) + if err != nil { + return 0, err + } + return calcFunction(name, args) +} + +func (p *calcParser) parseArguments(name string) ([]float64, error) { + p.skipSpace() + if p.peek() == ')' { + p.pos++ + return nil, nil + } + var args []float64 + for { + arg, err := p.parseExpression() + if err != nil { + return nil, err + } + args = append(args, arg) + p.skipSpace() + if p.peek() != ',' { break } - } - if bracketRight == 0 && bracketLeft != 0 || (bracketLeft > bracketRight) { - return 0, fmt.Errorf("Error:bracket not correct at %d ,except )", bracketLeft) - } - if bracketRight == 0 && bracketLeft == 0 { - return calcLong(math) - } - line := math[bracketLeft+1 : bracketRight] - num, err := calcLong(line) - if err != nil { - return 0, err - } - for i := bracketLeft - 1; i >= 0; i-- { - if !containSign(math[i : i+1]) { - DupStart = i - continue + p.pos++ + p.skipSpace() + if p.peek() == ')' { + return nil, fmt.Errorf("missing argument for function %q", name) } - break } - if DupStart != -1 { - sign := math[DupStart:bracketLeft] - num, err := calcDuaFloat(sign, num) - if err != nil { - return 0, err - } - math = math[:DupStart] + fmt.Sprintf("%.15f", num) + math[bracketRight+1:] - DupStart = -1 - } else { - math = math[:bracketLeft] + fmt.Sprintf("%.15f", num) + math[bracketRight+1:] + if p.peek() != ')' { + return nil, fmt.Errorf("missing ')' after function %q", name) } - return calc(math) + p.pos++ + return args, nil } -func calcLong(str string) (float64, error) { - var sigReady bool = false - var sigApply bool = false - var numPool []float64 - var operPool []string - var numStr string - var oper string - if str[0:1] == "+" || str[0:1] == "-" { - sigReady = true +func (p *calcParser) skipSpace() { + for !p.done() && unicode.IsSpace(rune(p.peek())) { + p.pos++ } - for _, stp := range str { - if sigReady && containSign(string(stp)) { - sigReady = false - sigApply = true - oper = string(stp) - continue - } - if !containSign(string(stp)) { - sigReady = false - numStr = string(append([]rune(numStr), stp)) - continue - } - if !sigReady { - sigReady = true - } - if sigApply { - num, err := calcDua(oper, numStr) - if err != nil { - return 0, err - } - sigApply = false - numPool = append(numPool, num) - } else { - num, err := parseNumbic(numStr) - if err != nil { - return 0, err - } - numPool = append(numPool, num) - } - numStr = "" - operPool = append(operPool, string(stp)) - } - if sigApply { - num, err := calcDua(oper, numStr) - if err != nil { - return 0, err - } - numPool = append(numPool, num) - } else { - num, err := parseNumbic(numStr) - if err != nil { - return 0, err - } - numPool = append(numPool, num) - } - return calcPool(numPool, operPool) } -func calcPool(numPool []float64, operPool []string) (float64, error) { - if len(numPool) == 1 && len(operPool) == 0 { - return numPool[0], nil - } - if len(numPool) < len(operPool) { - return 0, errors.New(("Operate Signal Is too much")) - } - calcFunc := func(k int, v string) (float64, error) { - num, err := calcSigFloat(numPool[k], v, numPool[k+1]) - if err != nil { - return 0, err - } - tmp := append(numPool[:k], num) - numPool = append(tmp, numPool[k+2:]...) - operPool = append(operPool[:k], operPool[k+1:]...) - return calcPool(numPool, operPool) - } - for k, v := range operPool { - if v == "^" { - return calcFunc(k, v) - } - } - for k, v := range operPool { - if v == "*" || v == "/" { - return calcFunc(k, v) - } - } - for k, v := range operPool { - return calcFunc(k, v) - } - return 0, nil +func (p *calcParser) done() bool { + return p.pos >= len(p.input) } -func calcSigFloat(floatA float64, b string, floatC float64) (float64, error) { - switch b { - case "+": - return floatRound(floatA + floatC,15), nil - case "-": - return floatRound(floatA - floatC,15), nil - case "*": - return floatRound(floatA * floatC,15), nil - case "/": - if floatC == 0 { - return 0, errors.New("Divisor cannot be 0") - } - return floatRound(floatA / floatC,15), nil - case "^": - return math.Pow(floatA, floatC), nil +func (p *calcParser) peek() byte { + if p.done() { + return 0 } - return 0, fmt.Errorf("unexpect method:%s", b) + return p.input[p.pos] } -func calcSig(a, b, c string) (float64, error) { - floatA, err := parseNumbic(a) - if err != nil { - return 0, err +func isCalcNumberStart(input string, pos int) bool { + ch := input[pos] + if ch >= '0' && ch <= '9' { + return true } - floatC, err := parseNumbic(c) - if err != nil { - return 0, err - } - return calcSigFloat(floatA, b, floatC) + return ch == '.' && pos+1 < len(input) && input[pos+1] >= '0' && input[pos+1] <= '9' } -func calcDuaFloat(a string, floatB float64) (float64, error) { - switch a { - case "sin": - return math.Sin(floatB), nil - case "cos": - return math.Cos(floatB), nil - case "tan": - return math.Tan(floatB), nil - case "abs": - return math.Abs(floatB), nil - case "arcsin": - return math.Asin(floatB), nil - case "arccos": - return math.Acos(floatB), nil - case "arctan": - return math.Atan(floatB), nil - case "sqrt": - return math.Sqrt(floatB), nil - case "loge": - return math.Log(floatB), nil - case "log10": - return math.Log10(floatB), nil - case "log2": - return math.Log2(floatB), nil - case "floor": - return math.Floor(floatB), nil - case "ceil": - return math.Ceil(floatB), nil - case "round": - return math.Round(floatB), nil - case "trunc": - return math.Trunc(floatB), nil - case "+": - return 0 + floatB, nil - case "-": - return 0 - floatB, nil - } - return 0, fmt.Errorf("unexpect method:%s", a) -} -func calcDua(a, b string) (float64, error) { - floatB, err := parseNumbic(b) - if err != nil { - return 0, err - } - return calcDuaFloat(a, floatB) +func isCalcIdentStart(ch byte) bool { + return ch >= 'a' && ch <= 'z' } -func parseNumbic(str string) (float64, error) { - switch str { +func isCalcIdent(ch byte) bool { + return (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' +} + +func calcConstant(name string) (float64, bool) { + switch name { case "pi": - return float64(math.Pi), nil + return math.Pi, true case "e": - return float64(math.E), nil + return math.E, true default: - return strconv.ParseFloat(str, 64) + return 0, false } } -func containSign(str string) bool { - var sign []string = []string{"+", "-", "*", "/", "^"} - for _, v := range sign { - if str == v { - return true - } +func calcFunction(name string, args []float64) (float64, error) { + argCount := len(args) + if !calcFunctionArgCountValid(name, argCount) { + return 0, fmt.Errorf("function %q accepts %s, got %d", name, calcFunctionArgSpec(name), argCount) + } + switch name { + case "sin": + return math.Sin(args[0]), nil + case "cos": + return math.Cos(args[0]), nil + case "tan": + return math.Tan(args[0]), nil + case "sinh": + return math.Sinh(args[0]), nil + case "cosh": + return math.Cosh(args[0]), nil + case "tanh": + return math.Tanh(args[0]), nil + case "abs": + return math.Abs(args[0]), nil + case "arcsin", "asin": + return math.Asin(args[0]), nil + case "arccos", "acos": + return math.Acos(args[0]), nil + case "arctan", "atan": + return math.Atan(args[0]), nil + case "sqrt": + return math.Sqrt(args[0]), nil + case "cbrt": + return math.Cbrt(args[0]), nil + case "exp": + return math.Exp(args[0]), nil + case "loge", "ln": + return math.Log(args[0]), nil + case "log": + return math.Log10(args[0]), nil + case "log10": + return math.Log10(args[0]), nil + case "log2": + return math.Log2(args[0]), nil + case "floor": + return math.Floor(args[0]), nil + case "ceil": + return math.Ceil(args[0]), nil + case "round": + return math.Round(args[0]), nil + case "trunc": + return math.Trunc(args[0]), nil + case "rad": + return args[0] * math.Pi / 180.0, nil + case "deg": + return args[0] * 180.0 / math.Pi, nil + case "pow": + return math.Pow(args[0], args[1]), nil + case "hypot": + return math.Hypot(args[0], args[1]), nil + case "min": + result := args[0] + for _, arg := range args[1:] { + if arg < result { + result = arg + } + } + return result, nil + case "max": + result := args[0] + for _, arg := range args[1:] { + if arg > result { + result = arg + } + } + return result, nil + default: + return 0, fmt.Errorf("unknown function %q", name) } - return false } -func contain(pool []string, str string) bool { - for _, v := range pool { - if v == str { - return true - } +func calcFunctionArgCountValid(name string, count int) bool { + switch name { + case "pow", "hypot": + return count == 2 + case "min", "max": + return count >= 1 + case "sin", "cos", "tan", "sinh", "cosh", "tanh", + "abs", "arcsin", "asin", "arccos", "acos", "arctan", "atan", + "sqrt", "cbrt", "exp", "loge", "ln", "log", "log10", "log2", + "floor", "ceil", "round", "trunc", "rad", "deg": + return count == 1 + default: + return true } - return false +} + +func calcFunctionArgSpec(name string) string { + switch name { + case "pow", "hypot": + return "exactly two arguments" + case "min", "max": + return "at least one argument" + default: + return "exactly one argument" + } +} + +func normalizeCalcFloat(value float64) float64 { + text := strconv.FormatFloat(value, 'g', 15, 64) + out, err := strconv.ParseFloat(text, 64) + if err != nil { + return value + } + if out == 0 { + return 0 + } + return out } diff --git a/math_test.go b/math_test.go new file mode 100644 index 0000000..a699205 --- /dev/null +++ b/math_test.go @@ -0,0 +1,61 @@ +package staros + +import ( + "math" + "testing" +) + +func TestCalcCompatibilityExpressions(t *testing.T) { + tests := []struct { + expr string + want float64 + }{ + {"1+2*3", 7}, + {"(1+2)*3", 9}, + {"60*60*24", 86400}, + {"-1+2", 1}, + {"sqrt(4)+abs(-3)", 5}, + {"sin(pi/2)", 1}, + {"arcsin(1)", math.Pi / 2}, + {"asin(1)", math.Pi / 2}, + {"loge(e)", 1}, + {"ln(e)", 1}, + {"log10(100)+log2(8)", 5}, + {"floor(1.9)+ceil(1.1)+round(1.5)+trunc(1.9)", 6}, + {"1.2e3+3", 1203}, + {"pow(2,3)+hypot(3,4)", 13}, + {"min(3,1,2)+max(3,1,2)", 4}, + {"log(100)+rad(180)/pi+deg(pi)/180", 4}, + {"cbrt(27)+exp(0)", 4}, + {"sinh(0)+cosh(0)+tanh(0)", 1}, + } + for _, tt := range tests { + got, err := Calc(tt.expr) + if err != nil { + t.Fatalf("Calc(%q) failed: %v", tt.expr, err) + } + if math.Abs(got-tt.want) > 1e-12 { + t.Fatalf("Calc(%q)=%v, want %v", tt.expr, got, tt.want) + } + } +} + +func TestCalcRejectsInvalidExpressions(t *testing.T) { + tests := []string{ + "", + "1/", + "(1+2", + "1/0", + "unknown(1)", + "min()", + "pow(2)", + "pow(2,3,4)", + "sqrt(1,2)", + "pi()", + } + for _, expr := range tests { + if got, err := Calc(expr); err == nil { + t.Fatalf("Calc(%q)=%v, expected error", expr, got) + } + } +} diff --git a/memory_darwin.go b/memory_darwin.go index 9bb85d5..9519baa 100644 --- a/memory_darwin.go +++ b/memory_darwin.go @@ -1,4 +1,5 @@ -//+build darwin +//go:build darwin +// +build darwin package staros @@ -13,7 +14,7 @@ import ( ) // Memory 系统内存信息 -func Memory() (MemStatus,error) { +func Memory() (MemStatus, error) { return darwinMemory() } diff --git a/memory_unix.go b/memory_unix.go index e77a682..9df308f 100644 --- a/memory_unix.go +++ b/memory_unix.go @@ -1,8 +1,14 @@ -//+build linux +//go:build linux +// +build linux package staros -import "syscall" +import ( + "io/ioutil" + "strconv" + "strings" + "syscall" +) // Memory 系统内存信息 func Memory() (MemStatus, error) { @@ -11,14 +17,40 @@ func Memory() (MemStatus, error) { if err := syscall.Sysinfo(ram); err != nil { return mem, err } - mem.All = uint64(ram.Totalram) - mem.BuffCache = uint64(ram.Bufferram) - mem.Free = uint64(ram.Freeram) - mem.Shared = uint64(ram.Sharedram) - mem.Available = uint64(ram.Freeram + ram.Sharedram + ram.Bufferram) - mem.SwapAll = uint64(ram.Totalswap) - mem.SwapFree = uint64(ram.Freeswap) + unit := uint64(ram.Unit) + if unit == 0 { + unit = 1 + } + mem.All = uint64(ram.Totalram) * unit + mem.BuffCache = uint64(ram.Bufferram) * unit + mem.Free = uint64(ram.Freeram) * unit + mem.Shared = uint64(ram.Sharedram) * unit + mem.Available = mem.Free + mem.Shared + mem.BuffCache + if available, ok := linuxMemAvailable(); ok { + mem.Available = available + } + mem.SwapAll = uint64(ram.Totalswap) * unit + mem.SwapFree = uint64(ram.Freeswap) * unit mem.SwapUsed = uint64(mem.SwapAll - mem.SwapFree) mem.Used = uint64(mem.All - mem.Free) return mem, nil } + +func linuxMemAvailable() (uint64, bool) { + data, err := ioutil.ReadFile("/proc/meminfo") + if err != nil { + return 0, false + } + for _, line := range strings.Split(string(data), "\n") { + fields := strings.Fields(line) + if len(fields) < 2 || fields[0] != "MemAvailable:" { + continue + } + value, err := strconv.ParseUint(fields[1], 10, 64) + if err != nil { + return 0, false + } + return value * 1024, true + } + return 0, false +} diff --git a/memory_windows.go b/memory_windows.go index 5995975..315189a 100644 --- a/memory_windows.go +++ b/memory_windows.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package staros @@ -19,8 +20,9 @@ func Memory() (MemStatus, error) { mem.SwapAll = uint64(ram.UllTotalPageFile) mem.SwapFree = uint64(ram.UllAvailPageFile) mem.SwapUsed = mem.SwapAll - mem.SwapFree - mem.VirtualAll = uint64(mem.VirtualAll) - mem.VirtualAvail = uint64(mem.VirtualAvail) - mem.VirtualUsed = mem.VirtualAll - mem.VirtualUsed + mem.VirtualAll = uint64(ram.UllTotalVirtual) + mem.VirtualAvail = uint64(ram.UllAvailVirtual) + mem.VirtualUsed = mem.VirtualAll - mem.VirtualAvail + mem.AvailExtended = uint64(ram.UllAvailExtendedVirtual) return mem, nil } diff --git a/network_darwin.go b/network_darwin.go new file mode 100644 index 0000000..4f892ca --- /dev/null +++ b/network_darwin.go @@ -0,0 +1,30 @@ +//go:build darwin +// +build darwin + +package staros + +import "time" + +func NetUsage() ([]NetAdapter, error) { + return nil, ERR_UNSUPPORTED +} + +func NetUsageByname(name string) (NetAdapter, error) { + return NetAdapter{}, ERR_UNSUPPORTED +} + +func NetSpeeds(duration time.Duration) ([]NetSpeed, error) { + return nil, ERR_UNSUPPORTED +} + +func NetSpeedsByName(duration time.Duration, name string) (NetSpeed, error) { + return NetSpeed{}, ERR_UNSUPPORTED +} + +func NetConnections(analysePid bool, types string) ([]NetConn, error) { + return nil, ERR_UNSUPPORTED +} + +func GetInodeMap() (map[string]int64, error) { + return nil, ERR_UNSUPPORTED +} diff --git a/network_test.go b/network_test.go index 1b30d7f..8ad8fe3 100644 --- a/network_test.go +++ b/network_test.go @@ -1,9 +1,67 @@ +//go:build linux +// +build linux + package staros import ( "testing" + "time" ) func Test_TrimSpace(t *testing.T) { } + +func TestAnalyseNetFilesSkipsShortLines(t *testing.T) { + data := []byte("sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\nshort\n") + res, err := analyseNetFiles(data, nil, "tcp") + if err != nil { + t.Fatal(err) + } + if len(res) != 0 { + t.Fatalf("expected no parsed connections, got %d", len(res)) + } +} + +func TestNetSpeedsRejectsInvalidDuration(t *testing.T) { + _, err := NetSpeeds(0) + if err == nil { + t.Fatal("expected invalid duration error") + } + _, err = NetSpeeds(-time.Second) + if err == nil { + t.Fatal("expected invalid duration error") + } +} + +func TestUniqueStrings(t *testing.T) { + got := uniqueStrings([]string{"tcp", "udp", "tcp"}) + if len(got) != 2 { + t.Fatalf("expected 2 unique values, got %d", len(got)) + } + if got[0] != "tcp" || got[1] != "udp" { + t.Fatalf("unexpected order: %#v", got) + } +} + +func TestParseProcStatusKB(t *testing.T) { + if got := parseProcStatusKB("12 kB"); got != 12*1024 { + t.Fatalf("expected 12288, got %d", got) + } + if got := parseProcStatusKB(""); got != 0 { + t.Fatalf("expected 0, got %d", got) + } +} + +func TestProcStartTimeFromStatHandlesProcessNameWithSpacesAndParens(t *testing.T) { + stat := []byte("42 (name with ) parens) S 1 1 1 0 -1 4194560 0 0 0 0 0 0 0 0 20 0 1 0 12345") + got, ok := procStartTimeFromStat(stat) + if !ok { + t.Fatal("expected proc stat start time to parse") + } + ticks := int64(clockTicks()) + want := time.Unix(StartTime().Unix()+12345/ticks, (12345%ticks)*int64(time.Second)/ticks) + if !got.Equal(want) { + t.Fatalf("unexpected start time: got %s want %s", got, want) + } +} diff --git a/network_unix.go b/network_unix.go index 5daefeb..793674e 100644 --- a/network_unix.go +++ b/network_unix.go @@ -1,5 +1,5 @@ -//go:build !windows -// +build !windows +//go:build linux +// +build linux package staros @@ -15,26 +15,37 @@ import ( func NetUsage() ([]NetAdapter, error) { data, err := ioutil.ReadFile("/proc/net/dev") if err != nil { - return []NetAdapter{}, err + return nil, err } sps := strings.Split(strings.TrimSpace(string(data)), "\n") if len(sps) < 3 { - return []NetAdapter{}, errors.New("No Adaptor") + return nil, errors.New("No Adaptor") } var res []NetAdapter netLists := sps[2:] for _, v := range netLists { - v = strings.ReplaceAll(v, " ", " ") - for strings.Contains(v, " ") { - v = strings.ReplaceAll(v, " ", " ") + parts := strings.SplitN(strings.TrimSpace(v), ":", 2) + if len(parts) != 2 { + continue + } + card := strings.Fields(parts[1]) + if len(card) < 16 { + continue + } + name := strings.TrimSpace(parts[0]) + recvBytes, err := strconv.ParseUint(card[0], 10, 64) + if err != nil { + continue + } + sendBytes, err := strconv.ParseUint(card[8], 10, 64) + if err != nil { + continue } - v = strings.TrimSpace(v) - card := strings.Split(v, " ") - name := strings.ReplaceAll(card[0], ":", "") - recvBytes, _ := strconv.Atoi(card[1]) - sendBytes, _ := strconv.Atoi(card[9]) res = append(res, NetAdapter{name, uint64(recvBytes), uint64(sendBytes)}) } + if len(res) == 0 { + return nil, errors.New("No Adaptor") + } return res, nil } @@ -52,30 +63,48 @@ func NetUsageByname(name string) (NetAdapter, error) { } func NetSpeeds(duration time.Duration) ([]NetSpeed, error) { + if duration <= 0 { + return nil, errors.New("duration must be positive") + } list1, err := NetUsage() if err != nil { - return []NetSpeed{}, err + return nil, err } time.Sleep(duration) list2, err := NetUsage() if err != nil { - return []NetSpeed{}, err + return nil, err } - if len(list1) > len(list2) { - return []NetSpeed{}, errors.New("NetWork Adaptor Num Not ok") + byName := make(map[string]NetAdapter, len(list2)) + for _, item := range list2 { + byName[item.Name] = item } var res []NetSpeed - for k, v := range list1 { - recv := float64(list2[k].RecvBytes-v.RecvBytes) / duration.Seconds() - send := float64(list2[k].SendBytes-v.SendBytes) / duration.Seconds() + for _, v := range list1 { + next, ok := byName[v.Name] + if !ok { + continue + } + var recvDelta, sendDelta uint64 + if next.RecvBytes >= v.RecvBytes { + recvDelta = next.RecvBytes - v.RecvBytes + } + if next.SendBytes >= v.SendBytes { + sendDelta = next.SendBytes - v.SendBytes + } + recv := float64(recvDelta) / duration.Seconds() + send := float64(sendDelta) / duration.Seconds() res = append(res, NetSpeed{ Name: v.Name, RecvSpeeds: recv, SendSpeeds: send, - RecvBytes: list2[k].RecvBytes, - SendBytes: list2[k].SendBytes, + RecvBytes: next.RecvBytes, + SendBytes: next.SendBytes, }) } + if len(res) == 0 { + return nil, errors.New("NetWork Adaptor Num Not ok") + } return res, nil } @@ -99,7 +128,8 @@ func NetConnections(analysePid bool, types string) ([]NetConn, error) { var inodeMap map[string]int64 var err error var fileList []string - if types == "" || strings.Contains(strings.ToLower(types), "all") { + types = strings.ToLower(types) + if types == "" || strings.Contains(types, "all") { fileList = []string{ "/proc/net/tcp", "/proc/net/tcp6", @@ -107,25 +137,33 @@ func NetConnections(analysePid bool, types string) ([]NetConn, error) { "/proc/net/udp6", "/proc/net/unix", } + } else { + if strings.Contains(types, "tcp") { + fileList = append(fileList, "/proc/net/tcp", "/proc/net/tcp6") + } + if strings.Contains(types, "udp") { + fileList = append(fileList, "/proc/net/udp", "/proc/net/udp6") + } + if strings.Contains(types, "unix") { + fileList = append(fileList, "/proc/net/unix") + } } - if strings.Contains(strings.ToLower(types), "tcp") { - fileList = append(fileList, "/proc/net/tcp", "/proc/net/tcp6") - } - if strings.Contains(strings.ToLower(types), "udp") { - fileList = append(fileList, "/proc/net/udp", "/proc/net/udp6") - } - if strings.Contains(strings.ToLower(types), "unix") { - fileList = append(fileList, "/proc/net/unix") + fileList = uniqueStrings(fileList) + if len(fileList) == 0 { + return nil, errors.New("unsupported net connection type") } if analysePid { inodeMap, err = GetInodeMap() if err != nil { - return result, err + inodeMap = nil } } for _, file := range fileList { data, err := ioutil.ReadFile(file) if err != nil { + if os.IsNotExist(err) { + continue + } return result, err } tmpRes, err := analyseNetFiles(data, inodeMap, file[strings.LastIndex(file, "/")+1:]) @@ -137,6 +175,22 @@ func NetConnections(analysePid bool, types string) ([]NetConn, error) { return result, nil } +func uniqueStrings(items []string) []string { + if len(items) == 0 { + return nil + } + seen := make(map[string]struct{}, len(items)) + res := make([]string, 0, len(items)) + for _, item := range items { + if _, ok := seen[item]; ok { + continue + } + seen[item] = struct{}{} + res = append(res, item) + } + return res +} + func GetInodeMap() (map[string]int64, error) { res := make(map[string]int64) paths, err := ioutil.ReadDir("/proc") @@ -186,6 +240,9 @@ func analyseNetFiles(data []byte, inodeMap map[string]int64, typed string) ([]Ne continue } v := strings.Split(strings.TrimSpace(lineData), " ") + if len(v) < 10 { + continue + } var res NetConn ip, port, err := parseHexIpPort(v[1]) if err != nil { @@ -205,7 +262,11 @@ func analyseNetFiles(data []byte, inodeMap map[string]int64, typed string) ([]Ne if err != nil { return result, err } - res.Status = TCP_STATE[state] + if state >= 0 && int(state) < len(TCP_STATE) { + res.Status = TCP_STATE[state] + } else { + res.Status = TCP_STATE[TCP_UNKNOWN] + } } txrx_queue := strings.Split(strings.TrimSpace(v[4]), ":") if len(txrx_queue) != 2 { @@ -293,6 +354,9 @@ func analyseUnixFiles(data []byte, inodeMap map[string]int64, typed string) ([]N continue } v := strings.Split(strings.TrimSpace(lineData), " ") + if len(v) < 7 { + continue + } var res NetConn res.Inode = v[6] if len(v) == 8 { diff --git a/network_windows.go b/network_windows.go index 1bf985c..fc3a1c9 100644 --- a/network_windows.go +++ b/network_windows.go @@ -1,33 +1,314 @@ +//go:build windows // +build windows package staros import ( + "errors" + "net" + "strconv" + "strings" + "syscall" "time" + + "b612.me/win32api" ) +const windowsErrorNotSupported syscall.Errno = 50 + func NetUsage() ([]NetAdapter, error) { - var res []NetAdapter + rows, err := win32api.GetIfTable2() + if err != nil { + return nil, err + } + res := make([]NetAdapter, 0, len(rows)) + for _, row := range rows { + name := windowsInterfaceName(row) + if name == "" { + continue + } + res = append(res, NetAdapter{ + Name: name, + RecvBytes: row.InOctets, + SendBytes: row.OutOctets, + }) + } + if len(res) == 0 { + return nil, errors.New("No Adaptor") + } return res, nil } func NetUsageByname(name string) (NetAdapter, error) { - return NetAdapter{}, nil + ada, err := NetUsage() + if err != nil { + return NetAdapter{}, err + } + for _, v := range ada { + if v.Name == name { + return v, nil + } + } + return NetAdapter{}, errors.New("Not Found") } func NetSpeeds(duration time.Duration) ([]NetSpeed, error) { - var res []NetSpeed + if duration <= 0 { + return nil, errors.New("duration must be positive") + } + list1, err := NetUsage() + if err != nil { + return nil, err + } + time.Sleep(duration) + list2, err := NetUsage() + if err != nil { + return nil, err + } + byName := make(map[string]NetAdapter, len(list2)) + for _, item := range list2 { + byName[item.Name] = item + } + res := make([]NetSpeed, 0, len(list1)) + for _, v := range list1 { + next, ok := byName[v.Name] + if !ok { + continue + } + var recvDelta, sendDelta uint64 + if next.RecvBytes >= v.RecvBytes { + recvDelta = next.RecvBytes - v.RecvBytes + } + if next.SendBytes >= v.SendBytes { + sendDelta = next.SendBytes - v.SendBytes + } + res = append(res, NetSpeed{ + Name: v.Name, + RecvSpeeds: float64(recvDelta) / duration.Seconds(), + SendSpeeds: float64(sendDelta) / duration.Seconds(), + RecvBytes: next.RecvBytes, + SendBytes: next.SendBytes, + }) + } + if len(res) == 0 { + return nil, errors.New("NetWork Adaptor Num Not ok") + } return res, nil } func NetSpeedsByName(duration time.Duration, name string) (NetSpeed, error) { - - return NetSpeed{}, nil + ada, err := NetSpeeds(duration) + if err != nil { + return NetSpeed{}, err + } + for _, v := range ada { + if v.Name == name { + return v, nil + } + } + return NetSpeed{}, errors.New("Not Found") } // NetConnections return all TCP/UDP/UNIX DOMAIN SOCKET Connections // if your uid != 0 ,and analysePid==true ,you should have CAP_SYS_PRTACE and CAP_DAC_OVERRIDE/CAP_DAC_READ_SEARCH Caps -func NetConnections(analysePid bool) ([]NetConn, error) { - var result []NetConn +func NetConnections(analysePid bool, types string) ([]NetConn, error) { + wantTCP, wantUDP, err := windowsNetConnectionTypes(types) + if err != nil { + return nil, err + } + + result := make([]NetConn, 0) + processCache := make(map[int64]*Process) + if wantTCP { + result, err = appendWindowsTCPConnections(result, analysePid, processCache) + if err != nil { + return result, err + } + } + if wantUDP { + result, err = appendWindowsUDPConnections(result, analysePid, processCache) + if err != nil { + return result, err + } + } return result, nil -} \ No newline at end of file +} + +func GetInodeMap() (map[string]int64, error) { + return nil, ERR_UNSUPPORTED +} + +func windowsInterfaceName(row win32api.MIB_IF_ROW2) string { + name := strings.TrimSpace(syscall.UTF16ToString(row.Alias[:])) + if name != "" { + return name + } + name = strings.TrimSpace(syscall.UTF16ToString(row.Description[:])) + if name != "" { + return name + } + if row.InterfaceIndex != 0 { + return "if" + strconv.FormatUint(uint64(row.InterfaceIndex), 10) + } + if row.InterfaceLuid != 0 { + return "luid" + strconv.FormatUint(row.InterfaceLuid, 10) + } + return "" +} + +func windowsNetConnectionTypes(types string) (wantTCP, wantUDP bool, err error) { + normalized := strings.ToLower(strings.TrimSpace(types)) + if strings.Contains(normalized, "unix") { + return false, false, ERR_UNSUPPORTED + } + if normalized == "" || strings.Contains(normalized, "all") { + return true, true, nil + } + if strings.Contains(normalized, "tcp") { + wantTCP = true + } + if strings.Contains(normalized, "udp") { + wantUDP = true + } + if !wantTCP && !wantUDP { + return false, false, errors.New("unsupported net connection type") + } + return wantTCP, wantUDP, nil +} + +func appendWindowsTCPConnections(result []NetConn, analysePid bool, processCache map[int64]*Process) ([]NetConn, error) { + rows4, err := win32api.GetExtendedTcp4Table(false, win32api.TCP_TABLE_OWNER_PID_ALL) + if err != nil { + return result, err + } + for _, row := range rows4 { + conn := NetConn{ + LocalAddr: windowsIPv4FromDWORD(row.LocalAddr), + LocalPort: int(row.LocalPortHost()), + RemoteAddr: windowsIPv4FromDWORD(row.RemoteAddr), + RemotePort: int(row.RemotePortHost()), + Status: windowsTCPState(row.State), + Typed: "tcp", + } + attachWindowsProcess(&conn, row.OwningPid, analysePid, processCache) + result = append(result, conn) + } + + rows6, err := win32api.GetExtendedTcp6Table(false, win32api.TCP_TABLE_OWNER_PID_ALL) + if err != nil { + if isOptionalWindowsNetTableError(err) { + return result, nil + } + return result, err + } + for _, row := range rows6 { + conn := NetConn{ + LocalAddr: net.IP(row.LocalAddr[:]).String(), + LocalPort: int(row.LocalPortHost()), + RemoteAddr: net.IP(row.RemoteAddr[:]).String(), + RemotePort: int(row.RemotePortHost()), + Status: windowsTCPState(row.State), + Typed: "tcp6", + } + attachWindowsProcess(&conn, row.OwningPid, analysePid, processCache) + result = append(result, conn) + } + return result, nil +} + +func appendWindowsUDPConnections(result []NetConn, analysePid bool, processCache map[int64]*Process) ([]NetConn, error) { + rows4, err := win32api.GetExtendedUdp4Table(false, win32api.UDP_TABLE_OWNER_PID) + if err != nil { + return result, err + } + for _, row := range rows4 { + conn := NetConn{ + LocalAddr: windowsIPv4FromDWORD(row.LocalAddr), + LocalPort: int(row.LocalPortHost()), + Typed: "udp", + } + attachWindowsProcess(&conn, row.OwningPid, analysePid, processCache) + result = append(result, conn) + } + + rows6, err := win32api.GetExtendedUdp6Table(false, win32api.UDP_TABLE_OWNER_PID) + if err != nil { + if isOptionalWindowsNetTableError(err) { + return result, nil + } + return result, err + } + for _, row := range rows6 { + conn := NetConn{ + LocalAddr: net.IP(row.LocalAddr[:]).String(), + LocalPort: int(row.LocalPortHost()), + Typed: "udp6", + } + attachWindowsProcess(&conn, row.OwningPid, analysePid, processCache) + result = append(result, conn) + } + return result, nil +} + +func attachWindowsProcess(conn *NetConn, pid uint32, analysePid bool, processCache map[int64]*Process) { + if conn == nil || !analysePid { + return + } + conn.Pid = int64(pid) + if conn.Pid <= 0 { + return + } + if proc, ok := processCache[conn.Pid]; ok { + conn.Process = proc + return + } + proc, err := FindProcessByPid(conn.Pid) + if err != nil { + processCache[conn.Pid] = nil + return + } + processCache[conn.Pid] = &proc + conn.Process = &proc +} + +func windowsIPv4FromDWORD(addr uint32) string { + return net.IPv4(byte(addr), byte(addr>>8), byte(addr>>16), byte(addr>>24)).String() +} + +func windowsTCPState(state win32api.MIB_TCP_STATE) string { + switch state { + case win32api.MIB_TCP_STATE_CLOSED: + return TCP_STATE[TCP_CLOSE] + case win32api.MIB_TCP_STATE_LISTEN: + return TCP_STATE[TCP_LISTEN] + case win32api.MIB_TCP_STATE_SYN_SENT: + return TCP_STATE[TCP_SYN_SENT] + case win32api.MIB_TCP_STATE_SYN_RCVD: + return TCP_STATE[TCP_SYN_RECV] + case win32api.MIB_TCP_STATE_ESTAB: + return TCP_STATE[TCP_ESTABLISHED] + case win32api.MIB_TCP_STATE_FIN_WAIT1: + return TCP_STATE[TCP_FIN_WAIT1] + case win32api.MIB_TCP_STATE_FIN_WAIT2: + return TCP_STATE[TCP_FIN_WAIT2] + case win32api.MIB_TCP_STATE_CLOSE_WAIT: + return TCP_STATE[TCP_CLOSE_WAIT] + case win32api.MIB_TCP_STATE_CLOSING: + return TCP_STATE[TCP_CLOSING] + case win32api.MIB_TCP_STATE_LAST_ACK: + return TCP_STATE[TCP_LAST_ACK] + case win32api.MIB_TCP_STATE_TIME_WAIT: + return TCP_STATE[TCP_TIME_WAIT] + case win32api.MIB_TCP_STATE_DELETE_TCB: + return "TCP_DELETE_TCB" + default: + return TCP_STATE[TCP_UNKNOWN] + } +} + +func isOptionalWindowsNetTableError(err error) bool { + if errno, ok := err.(syscall.Errno); ok { + return errno == windowsErrorNotSupported + } + return false +} diff --git a/network_windows_test.go b/network_windows_test.go new file mode 100644 index 0000000..0a410b2 --- /dev/null +++ b/network_windows_test.go @@ -0,0 +1,121 @@ +//go:build windows +// +build windows + +package staros + +import ( + "errors" + "syscall" + "testing" + + "b612.me/win32api" +) + +func TestWindowsNetConnectionTypes(t *testing.T) { + tcp, udp, err := windowsNetConnectionTypes("") + if err != nil { + t.Fatal(err) + } + if !tcp || !udp { + t.Fatalf("empty types should request tcp and udp, got tcp=%v udp=%v", tcp, udp) + } + + tcp, udp, err = windowsNetConnectionTypes("all") + if err != nil { + t.Fatal(err) + } + if !tcp || !udp { + t.Fatalf("all types should request tcp and udp, got tcp=%v udp=%v", tcp, udp) + } + + tcp, udp, err = windowsNetConnectionTypes("tcp") + if err != nil { + t.Fatal(err) + } + if !tcp || udp { + t.Fatalf("tcp types mismatch: tcp=%v udp=%v", tcp, udp) + } + + tcp, udp, err = windowsNetConnectionTypes("TCP,UDP") + if err != nil { + t.Fatal(err) + } + if !tcp || !udp { + t.Fatalf("mixed tcp/udp types mismatch: tcp=%v udp=%v", tcp, udp) + } + + tcp, udp, err = windowsNetConnectionTypes("udp") + if err != nil { + t.Fatal(err) + } + if tcp || !udp { + t.Fatalf("udp types mismatch: tcp=%v udp=%v", tcp, udp) + } + + if _, _, err = windowsNetConnectionTypes("unix"); !errors.Is(err, ERR_UNSUPPORTED) { + t.Fatalf("unix should be unsupported on windows, got %v", err) + } + if _, _, err = windowsNetConnectionTypes("tcp,unix"); !errors.Is(err, ERR_UNSUPPORTED) { + t.Fatalf("mixed unix request should be unsupported on windows, got %v", err) + } + if _, _, err = windowsNetConnectionTypes("all,unix"); !errors.Is(err, ERR_UNSUPPORTED) { + t.Fatalf("all plus unix request should be unsupported on windows, got %v", err) + } + if _, _, err = windowsNetConnectionTypes("raw"); err == nil { + t.Fatal("unknown type should return error") + } +} + +func TestWindowsIPv4FromDWORD(t *testing.T) { + if got := windowsIPv4FromDWORD(0x0100007f); got != "127.0.0.1" { + t.Fatalf("unexpected localhost conversion: %s", got) + } +} + +func TestWindowsTCPState(t *testing.T) { + cases := map[win32api.MIB_TCP_STATE]string{ + win32api.MIB_TCP_STATE_ESTAB: TCP_STATE[TCP_ESTABLISHED], + win32api.MIB_TCP_STATE_LISTEN: TCP_STATE[TCP_LISTEN], + win32api.MIB_TCP_STATE_SYN_SENT: TCP_STATE[TCP_SYN_SENT], + win32api.MIB_TCP_STATE_SYN_RCVD: TCP_STATE[TCP_SYN_RECV], + win32api.MIB_TCP_STATE_FIN_WAIT1: TCP_STATE[TCP_FIN_WAIT1], + win32api.MIB_TCP_STATE_FIN_WAIT2: TCP_STATE[TCP_FIN_WAIT2], + win32api.MIB_TCP_STATE_TIME_WAIT: TCP_STATE[TCP_TIME_WAIT], + win32api.MIB_TCP_STATE_CLOSED: TCP_STATE[TCP_CLOSE], + win32api.MIB_TCP_STATE_CLOSE_WAIT: TCP_STATE[TCP_CLOSE_WAIT], + win32api.MIB_TCP_STATE_LAST_ACK: TCP_STATE[TCP_LAST_ACK], + win32api.MIB_TCP_STATE_CLOSING: TCP_STATE[TCP_CLOSING], + } + for state, want := range cases { + if got := windowsTCPState(state); got != want { + t.Fatalf("state %d mismatch: got=%s want=%s", state, got, want) + } + } + if got := windowsTCPState(win32api.MIB_TCP_STATE(0)); got != TCP_STATE[TCP_UNKNOWN] { + t.Fatalf("unknown state mismatch: %s", got) + } +} + +func TestIsOptionalWindowsNetTableError(t *testing.T) { + if !isOptionalWindowsNetTableError(windowsErrorNotSupported) { + t.Fatal("ERROR_NOT_SUPPORTED should be optional") + } + if isOptionalWindowsNetTableError(syscall.EINVAL) { + t.Fatal("EINVAL should not be optional") + } +} + +func TestAttachWindowsProcess(t *testing.T) { + conn := NetConn{} + cache := map[int64]*Process{} + attachWindowsProcess(&conn, 123, false, cache) + if conn.Pid != 0 || conn.Process != nil { + t.Fatalf("analysePid=false should not populate process fields: %#v", conn) + } + + conn = NetConn{} + attachWindowsProcess(&conn, 0, true, cache) + if conn.Pid != 0 || conn.Process != nil { + t.Fatalf("pid 0 should not populate process fields: %#v", conn) + } +} diff --git a/os.go b/os.go index 49fca98..600cb82 100644 --- a/os.go +++ b/os.go @@ -1,19 +1,34 @@ package staros import ( + "fmt" "os/user" "strconv" ) +func parseUint32Identity(kind, raw string) (uint32, error) { + value, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + return 0, fmt.Errorf("parse %s %q: %w", kind, raw, err) + } + return uint32(value), nil +} + // GetUidGid func GetUidGid(uname string) (uint32, uint32, string, error) { usr, err := user.Lookup(uname) if err != nil { return 0, 0, "", err } - uidInt, _ := strconv.Atoi(usr.Uid) - gidInt, _ := strconv.Atoi(usr.Gid) - return uint32(uidInt), uint32(gidInt), usr.HomeDir, nil + uid, err := parseUint32Identity("uid", usr.Uid) + if err != nil { + return 0, 0, "", err + } + gid, err := parseUint32Identity("gid", usr.Gid) + if err != nil { + return 0, 0, "", err + } + return uid, gid, usr.HomeDir, nil } // GetUid @@ -22,26 +37,23 @@ func GetUid(uname string) (uint32, error) { if err != nil { return 0, err } - uidInt, _ := strconv.Atoi(usr.Uid) - return uint32(uidInt), nil + return parseUint32Identity("uid", usr.Uid) } // GetGid func GetGid(uname string) (uint32, error) { - usr, err := user.LookupGroup(uname) - if err != nil { - return 0, err - } - gidInt, _ := strconv.Atoi(usr.Gid) - return uint32(gidInt), nil -} - -// GetGidByName -func GetGidByName(uname string) (uint32, error) { usr, err := user.Lookup(uname) if err != nil { return 0, err } - uidInt, _ := strconv.Atoi(usr.Gid) - return uint32(uidInt), nil + return parseUint32Identity("gid", usr.Gid) +} + +// GetGidByName +func GetGidByName(uname string) (uint32, error) { + usr, err := user.LookupGroup(uname) + if err != nil { + return 0, err + } + return parseUint32Identity("gid", usr.Gid) } diff --git a/os_darwin.go b/os_darwin.go new file mode 100644 index 0000000..79ec511 --- /dev/null +++ b/os_darwin.go @@ -0,0 +1,69 @@ +//go:build darwin +// +build darwin + +package staros + +import ( + "os/user" + "strconv" + "syscall" + "time" +) + +// StartTime is not implemented on Darwin yet. +func StartTime() time.Time { + return time.Time{} +} + +// IsRoot 当前是否是管理员用户 +func IsRoot() bool { + uid, err := user.Current() + return err == nil && uid.Uid == "0" +} + +func Whoami() (uid, gid int, uname, gname, home string, err error) { + var me *user.User + var group *user.Group + me, err = user.Current() + if err != nil { + return + } + uid, _ = strconv.Atoi(me.Uid) + gid, _ = strconv.Atoi(me.Gid) + home = me.HomeDir + uname = me.Username + group, err = user.LookupGroupId(me.Gid) + if err != nil { + return + } + gname = group.Name + return +} + +func CpuUsageByPid(pid int, sleep time.Duration) float64 { + return 0 +} + +func CpuUsage(sleep time.Duration) float64 { + return 0 +} + +func DiskUsage(path string) (disk DiskStatus) { + disk, _ = DiskUsageE(path) + return +} + +func DiskUsageE(path string) (disk DiskStatus, err error) { + if path == "" { + path = "." + } + fs := syscall.Statfs_t{} + if err = syscall.Statfs(path, &fs); err != nil { + return + } + disk.All = fs.Blocks * uint64(fs.Bsize) + disk.Free = fs.Bfree * uint64(fs.Bsize) + disk.Available = fs.Bavail * uint64(fs.Bsize) + disk.Used = disk.All - disk.Free + return +} diff --git a/os_test.go b/os_test.go index 054dbc9..3876563 100644 --- a/os_test.go +++ b/os_test.go @@ -1,11 +1,110 @@ package staros import ( - "fmt" + "errors" + "os/user" + "strconv" "testing" + "time" ) func Test_Disk(t *testing.T) { - a := DiskUsage("c:") - fmt.Println(a) + disk, err := DiskUsageE(".") + if err != nil { + t.Fatal(err) + } + if disk.All == 0 { + t.Fatal("expected non-zero total disk size") + } + if disk.Used+disk.Free != disk.All { + t.Fatalf("expected used + free == all, got used=%d free=%d all=%d", disk.Used, disk.Free, disk.All) + } +} + +func TestCpuUsageDoesNotPanic(t *testing.T) { + _ = CpuUsage(time.Millisecond) +} + +func TestWhoamiGID(t *testing.T) { + _, gid, _, _, _, err := Whoami() + if errors.Is(err, ERR_UNSUPPORTED) { + t.Skip(err) + } + if err != nil { + t.Fatal(err) + } + current, err := user.Current() + if err != nil { + t.Fatal(err) + } + expected, err := strconv.Atoi(current.Gid) + if err != nil { + t.Fatal(err) + } + if gid != expected { + t.Fatalf("expected gid %d, got %d", expected, gid) + } +} + +func TestIdentityLookupFunctions(t *testing.T) { + current, err := user.Current() + if err != nil { + t.Skipf("user.Current unavailable: %v", err) + } + + wantUID, uidErr := strconv.ParseUint(current.Uid, 10, 32) + wantGID, gidErr := strconv.ParseUint(current.Gid, 10, 32) + + uid, gid, home, err := GetUidGid(current.Username) + if uidErr == nil && gidErr == nil { + if err != nil { + t.Fatalf("GetUidGid failed: %v", err) + } + if uid != uint32(wantUID) || gid != uint32(wantGID) || home != current.HomeDir { + t.Fatalf("GetUidGid mismatch: uid=%d gid=%d home=%q", uid, gid, home) + } + } else if err == nil { + t.Fatalf("GetUidGid should reject non-numeric ids: uid=%q gid=%q", current.Uid, current.Gid) + } + + uid, err = GetUid(current.Username) + if uidErr == nil { + if err != nil { + t.Fatalf("GetUid failed: %v", err) + } + if uid != uint32(wantUID) { + t.Fatalf("GetUid mismatch: got=%d want=%d", uid, wantUID) + } + } else if err == nil { + t.Fatalf("GetUid should reject non-numeric uid %q", current.Uid) + } + + gid, err = GetGid(current.Username) + if gidErr == nil { + if err != nil { + t.Fatalf("GetGid failed: %v", err) + } + if gid != uint32(wantGID) { + t.Fatalf("GetGid mismatch: got=%d want=%d", gid, wantGID) + } + } else if err == nil { + t.Fatalf("GetGid should reject non-numeric gid %q", current.Gid) + } + + group, err := user.LookupGroupId(current.Gid) + if err != nil { + t.Skipf("user.LookupGroupId unavailable: %v", err) + } + groupID, groupErr := strconv.ParseUint(group.Gid, 10, 32) + gotGroupID, err := GetGidByName(group.Name) + if groupErr == nil { + if err != nil { + t.Fatalf("GetGidByName failed: %v", err) + } + if gotGroupID != uint32(groupID) { + t.Fatalf("GetGidByName mismatch: got=%d want=%d", gotGroupID, groupID) + } + } else if err == nil { + t.Fatalf("GetGidByName should reject non-numeric gid %q", group.Gid) + } } diff --git a/os_unix.go b/os_unix.go index e31e453..6216208 100644 --- a/os_unix.go +++ b/os_unix.go @@ -1,19 +1,27 @@ -// +build linux darwin unix +//go:build linux +// +build linux package staros import ( "bytes" - "fmt" + "encoding/binary" + "errors" "io/ioutil" + "os" "os/user" "strconv" "strings" + "sync" "syscall" "time" + "unsafe" ) -var clockTicks = 100 // default value +var ( + clockTicksOnce sync.Once + clockTicksValue uint64 = 100 +) // StartTime 开机时间 func StartTime() time.Time { @@ -25,11 +33,8 @@ func StartTime() time.Time { // IsRoot 当前是否是管理员用户 func IsRoot() bool { - uid, _ := user.Current() - if uid.Uid == "0" { - return true - } - return false + uid, err := user.Current() + return err == nil && uid != nil && uid.Uid == "0" } func Whoami() (uid, gid int, uname, gname, home string, err error) { @@ -40,7 +45,7 @@ func Whoami() (uid, gid int, uname, gname, home string, err error) { return } uid, _ = strconv.Atoi(me.Uid) - gid, _ = strconv.Atoi(me.Uid) + gid, _ = strconv.Atoi(me.Gid) home = me.HomeDir uname = me.Username gup, err = user.LookupGroupId(me.Gid) @@ -51,6 +56,79 @@ func Whoami() (uid, gid int, uname, gname, home string, err error) { return } +func clockTicks() uint64 { + clockTicksOnce.Do(initClockTicks) + if clockTicksValue == 0 { + return 100 + } + return clockTicksValue +} + +func initClockTicks() { + ticks, err := readClockTicksFromAuxv() + if err != nil || ticks == 0 { + return + } + clockTicksValue = ticks +} + +func readClockTicksFromAuxv() (uint64, error) { + data, err := os.ReadFile("/proc/self/auxv") + if err != nil { + return 0, err + } + wordSize := int(unsafe.Sizeof(uintptr(0))) + if wordSize != 4 && wordSize != 8 { + return 0, errors.New("unsupported pointer size") + } + order := nativeEndian() + entrySize := wordSize * 2 + for offset := 0; offset+entrySize <= len(data); offset += entrySize { + key := readAuxvWord(data[offset:offset+wordSize], order) + val := readAuxvWord(data[offset+wordSize:offset+entrySize], order) + if key == 0 { + break + } + if key == 17 { + return val, nil + } + } + return 0, errors.New("AT_CLKTCK not found") +} + +func readAuxvWord(data []byte, order binary.ByteOrder) uint64 { + if len(data) == 4 { + return uint64(order.Uint32(data)) + } + return order.Uint64(data) +} + +func nativeEndian() binary.ByteOrder { + var n uint16 = 1 + if *(*byte)(unsafe.Pointer(&n)) == 1 { + return binary.LittleEndian + } + return binary.BigEndian +} + +func cpuUsageOverDuration(delta float64, sleep time.Duration) float64 { + if delta < 0 || sleep <= 0 { + return 0 + } + seconds := sleep.Seconds() + if seconds <= 0 { + return 0 + } + return delta / seconds * 100 +} + +func cpuUsagePercent(busyTicks, totalTicks float64) float64 { + if busyTicks < 0 || totalTicks <= 0 { + return 0 + } + return 100 * busyTicks / totalTicks +} + func getCPUSample() (idle, total uint64) { contents, err := ioutil.ReadFile("/proc/stat") if err != nil { @@ -59,12 +137,15 @@ func getCPUSample() (idle, total uint64) { lines := strings.Split(string(contents), "\n") for _, line := range lines { fields := strings.Fields(line) + if len(fields) == 0 { + continue + } if fields[0] == "cpu" { numFields := len(fields) for i := 1; i < numFields; i++ { val, err := strconv.ParseUint(fields[i], 10, 64) if err != nil { - fmt.Println("Error: ", i, fields[i], err) + continue } total += val // tally up all the numbers to get total ticks if i == 4 || i == 5 { // idle is the 5th field in the cpu line @@ -117,31 +198,45 @@ func getCPUSampleByPid(pid int) float64 { } else { iotime = 0 // e.g. SmartOS containers } - return utime/float64(clockTicks) + stime/float64(clockTicks) + iotime/float64(clockTicks) + ticks := float64(clockTicks()) + return utime/ticks + stime/ticks + iotime/ticks } func CpuUsageByPid(pid int, sleep time.Duration) float64 { + if sleep <= 0 { + return 0 + } total1 := getCPUSampleByPid(pid) time.Sleep(sleep) total2 := getCPUSampleByPid(pid) - return (total2 - total1) / sleep.Seconds() * 100 + return cpuUsageOverDuration(total2-total1, sleep) } // CpuUsage 获取CPU使用量 func CpuUsage(sleep time.Duration) float64 { + if sleep <= 0 { + return 0 + } idle0, total0 := getCPUSample() time.Sleep(sleep) idle1, total1 := getCPUSample() idleTicks := float64(idle1 - idle0) totalTicks := float64(total1 - total0) - cpuUsage := 100 * (totalTicks - idleTicks) / totalTicks + cpuUsage := cpuUsagePercent(totalTicks-idleTicks, totalTicks) return cpuUsage //fmt.Printf("CPU usage is %f%% [busy: %f, total: %f]\n", cpuUsage, totalTicks-idleTicks, totalTicks) } func DiskUsage(path string) (disk DiskStatus) { + disk, _ = DiskUsageE(path) + return +} + +func DiskUsageE(path string) (disk DiskStatus, err error) { + if path == "" { + path = "." + } fs := syscall.Statfs_t{} - err := syscall.Statfs(path, &fs) - if err != nil { + if err = syscall.Statfs(path, &fs); err != nil { return } disk.All = fs.Blocks * uint64(fs.Bsize) diff --git a/os_unix_test.go b/os_unix_test.go new file mode 100644 index 0000000..0d8e01a --- /dev/null +++ b/os_unix_test.go @@ -0,0 +1,36 @@ +//go:build linux +// +build linux + +package staros + +import ( + "testing" + "time" +) + +func TestCPUUsageOverDurationGuardsZeroOrNegativeWindow(t *testing.T) { + if got := cpuUsageOverDuration(10, 0); got != 0 { + t.Fatalf("expected zero-window cpu usage to clamp to 0, got %v", got) + } + if got := cpuUsageOverDuration(10, -time.Millisecond); got != 0 { + t.Fatalf("expected negative-window cpu usage to clamp to 0, got %v", got) + } + if got := cpuUsageOverDuration(-1, time.Second); got != 0 { + t.Fatalf("expected negative delta cpu usage to clamp to 0, got %v", got) + } + if got := cpuUsageOverDuration(0.5, time.Second); got != 50 { + t.Fatalf("expected normal cpu usage calculation, got %v", got) + } +} + +func TestCPUUsagePercentGuardsInvalidSamples(t *testing.T) { + if got := cpuUsagePercent(1, 0); got != 0 { + t.Fatalf("expected zero total ticks to clamp to 0, got %v", got) + } + if got := cpuUsagePercent(-1, 4); got != 0 { + t.Fatalf("expected negative busy ticks to clamp to 0, got %v", got) + } + if got := cpuUsagePercent(1, 4); got != 25 { + t.Fatalf("expected normal cpu percent calculation, got %v", got) + } +} diff --git a/os_windows.go b/os_windows.go index 128abb2..1f8fc01 100644 --- a/os_windows.go +++ b/os_windows.go @@ -1,9 +1,9 @@ +//go:build windows // +build windows package staros import ( - "log" "syscall" "time" "unsafe" @@ -27,27 +27,33 @@ func IsRoot() bool { return wincmd.Isas() } - func DiskUsage(path string) (disk DiskStatus) { - kernel32, err := syscall.LoadLibrary("Kernel32.dll") - if err != nil { - log.Panic(err) - } - defer syscall.FreeLibrary(kernel32) - GetDiskFreeSpaceEx, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GetDiskFreeSpaceExW") + disk, _ = DiskUsageE(path) + return +} - if err != nil { - log.Panic(err) +func DiskUsageE(path string) (disk DiskStatus, err error) { + if path == "" { + path = "." } - lpFreeBytesAvailable := int64(0) lpTotalNumberOfBytes := int64(0) lpTotalNumberOfFreeBytes := int64(0) - syscall.Syscall6(uintptr(GetDiskFreeSpaceEx), 4, - uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr("C:"))), + + path16, err := syscall.UTF16PtrFromString(path) + if err != nil { + return + } + r1, _, callErr := syscall.NewLazyDLL("kernel32.dll").NewProc("GetDiskFreeSpaceExW").Call( + uintptr(unsafe.Pointer(path16)), uintptr(unsafe.Pointer(&lpFreeBytesAvailable)), uintptr(unsafe.Pointer(&lpTotalNumberOfBytes)), - uintptr(unsafe.Pointer(&lpTotalNumberOfFreeBytes)), 0, 0) + uintptr(unsafe.Pointer(&lpTotalNumberOfFreeBytes)), + ) + if r1 == 0 { + err = callErr + return + } disk.Free = uint64(lpTotalNumberOfFreeBytes) disk.Used = uint64(lpTotalNumberOfBytes - lpTotalNumberOfFreeBytes) disk.All = uint64(lpTotalNumberOfBytes) diff --git a/process.go b/process.go index 39b75f7..3f91040 100644 --- a/process.go +++ b/process.go @@ -7,156 +7,334 @@ import ( "io" "os" "os/exec" - "strings" "sync" "sync/atomic" - "syscall" "time" ) +var errNilCommand = errors.New("nil command") +var errCommandStdinUnavailable = errors.New("command stdin is not available") +var errCommandProcessNotStarted = errors.New("command process is not started") +var errCommandAlreadyStarted = errors.New("command already started") +var errCommandAlreadyReleased = errors.New("command already released") +var errCommandStdinClosed = errors.New("command stdin is closed") +var errCommandAlreadyDetached = errors.New("command already detached") +var errCommandDetached = errors.New("command already detached") +var errCommandRedirectNil = errors.New("command redirect target is nil") + +const starCmdUnknownExitCode = -999 +const starCmdStreamBuffer = 64 + +type starCmdStream int + +const ( + starCmdStdout starCmdStream = iota + starCmdStderr +) + +// StarCmdOutputStream identifies which process stream produced a chunk. +type StarCmdOutputStream int + +const ( + StarCmdOutputStdout StarCmdOutputStream = iota + StarCmdOutputStderr +) + +// StarCmdOutput is a streamed stdout/stderr chunk. +type StarCmdOutput struct { + Stream StarCmdOutputStream + Data []byte +} + +type starCmdWriter struct { + cmd *StarCmd + stream starCmdStream +} + +func (writer starCmdWriter) Write(data []byte) (int, error) { + if writer.cmd == nil { + return 0, errNilCommand + } + writer.cmd.lock.Lock() + writer.cmd.ensureBuffers() + var redirect io.Writer + switch writer.stream { + case starCmdStdout: + if _, err := writer.cmd.stdoutBuf.Write(data); err != nil { + writer.cmd.lock.Unlock() + return 0, err + } + writer.cmd.stdout = append(writer.cmd.stdout, data...) + writer.cmd.publishStreamLocked(starCmdStdout, data) + redirect = writer.cmd.stdoutRedirect + case starCmdStderr: + if _, err := writer.cmd.stderrBuf.Write(data); err != nil { + writer.cmd.lock.Unlock() + return 0, err + } + writer.cmd.errout = append(writer.cmd.errout, data...) + writer.cmd.publishStreamLocked(starCmdStderr, data) + redirect = writer.cmd.stderrRedirect + default: + writer.cmd.lock.Unlock() + return 0, errors.New("unknown command stream") + } + writer.cmd.lock.Unlock() + if redirect != nil { + writer.cmd.redirectLock.Lock() + n, err := redirect.Write(data) + writer.cmd.redirectLock.Unlock() + if err != nil { + return n, err + } + if n != len(data) { + return n, io.ErrShortWrite + } + } + return len(data), nil +} + //StarCmd Is Here type StarCmd struct { - CMD *exec.Cmd - outfile io.ReadCloser - infile io.WriteCloser - errfile io.ReadCloser - running int32 + CMD *exec.Cmd + infile io.WriteCloser + inclosed bool + running int32 + started int32 + released int32 + detached int32 //Store AlL of the Standed Outputs stdout []byte //Store All of the Standed Errors - errout []byte - runerr error - exitcode int - stdoutBuf *bytes.Buffer - stderrBuf *bytes.Buffer - stdoutpoint int - stderrpoint int - lock sync.Mutex - prewrite []string - prewritetime time.Duration - stopctxfunc context.CancelFunc - stopctx context.Context + errout []byte + runerr error + exitcode int + stdoutBuf *bytes.Buffer + stderrBuf *bytes.Buffer + lock sync.Mutex + prewrite []string + prewritetime time.Duration + stopctxfunc context.CancelFunc + stopctx context.Context + doneOnce sync.Once + done chan struct{} + resultOnce sync.Once + resultDone chan struct{} + stdoutStream []chan []byte + stderrStream []chan []byte + outputStream []chan StarCmdOutput + streamClosed bool + stdoutRedirect io.Writer + stderrRedirect io.Writer + redirectLock sync.Mutex + closeAfter []io.Closer +} + +func (starcli *StarCmd) ensureBuffers() { + if starcli.stdoutBuf == nil { + starcli.stdoutBuf = bytes.NewBuffer(make([]byte, 0)) + } + if starcli.stderrBuf == nil { + starcli.stderrBuf = bytes.NewBuffer(make([]byte, 0)) + } +} + +func (starcli *StarCmd) ensureStopContext() { + if starcli.stopctx == nil || starcli.stopctxfunc == nil { + starcli.stopctx, starcli.stopctxfunc = context.WithCancel(context.Background()) + } + if starcli.done == nil { + starcli.done = make(chan struct{}) + } + if starcli.resultDone == nil { + starcli.resultDone = make(chan struct{}) + } +} + +func (starcli *StarCmd) ensureResultDone() <-chan struct{} { + if starcli == nil { + closed := make(chan struct{}) + close(closed) + return closed + } + starcli.lock.Lock() + starcli.ensureStopContext() + done := starcli.resultDone + starcli.lock.Unlock() + return done +} + +func (starcli *StarCmd) signalResultDone() { + if starcli == nil { + return + } + starcli.resultOnce.Do(func() { + starcli.lock.Lock() + starcli.ensureStopContext() + done := starcli.resultDone + starcli.lock.Unlock() + close(done) + }) +} + +func (starcli *StarCmd) finish() { + if starcli == nil { + return + } + starcli.setRunning(false) + starcli.signalResultDone() + if starcli.stopctxfunc != nil { + starcli.stopctxfunc() + } + starcli.doneOnce.Do(func() { + starcli.lock.Lock() + done := starcli.done + stdoutStream := starcli.stdoutStream + stderrStream := starcli.stderrStream + outputStream := starcli.outputStream + infile := starcli.infile + closeAfter := starcli.closeAfter + starcli.stdoutStream = nil + starcli.stderrStream = nil + starcli.outputStream = nil + starcli.closeAfter = nil + starcli.streamClosed = true + if infile != nil && !starcli.inclosed { + starcli.inclosed = true + } else { + infile = nil + } + starcli.lock.Unlock() + for _, stream := range stdoutStream { + close(stream) + } + for _, stream := range stderrStream { + close(stream) + } + for _, stream := range outputStream { + close(stream) + } + if infile != nil { + _ = infile.Close() + } + for _, closer := range closeAfter { + _ = closer.Close() + } + if done != nil { + close(done) + } + }) +} + +func (starcli *StarCmd) publishStreamLocked(stream starCmdStream, data []byte) { + if starcli == nil || starcli.streamClosed { + return + } + switch stream { + case starCmdStdout: + for _, receiver := range starcli.stdoutStream { + select { + case receiver <- append([]byte(nil), data...): + default: + } + } + for _, receiver := range starcli.outputStream { + select { + case receiver <- StarCmdOutput{Stream: StarCmdOutputStdout, Data: append([]byte(nil), data...)}: + default: + } + } + case starCmdStderr: + for _, receiver := range starcli.stderrStream { + select { + case receiver <- append([]byte(nil), data...): + default: + } + } + for _, receiver := range starcli.outputStream { + select { + case receiver <- StarCmdOutput{Stream: StarCmdOutputStderr, Data: append([]byte(nil), data...)}: + default: + } + } + } +} + +func (starcli *StarCmd) registerByteStream(selectStream starCmdStream) <-chan []byte { + stream := make(chan []byte, starCmdStreamBuffer) + if starcli == nil { + close(stream) + return stream + } + starcli.lock.Lock() + if starcli.streamClosed { + close(stream) + } else { + switch selectStream { + case starCmdStdout: + starcli.stdoutStream = append(starcli.stdoutStream, stream) + case starCmdStderr: + starcli.stderrStream = append(starcli.stderrStream, stream) + default: + close(stream) + } + } + starcli.lock.Unlock() + return stream +} + +func (starcli *StarCmd) ensureConfigurable() error { + if starcli == nil || starcli.CMD == nil { + return errNilCommand + } + if atomic.LoadInt32(&starcli.started) != 0 { + return errCommandAlreadyStarted + } + if atomic.LoadInt32(&starcli.detached) != 0 { + return errCommandDetached + } + return nil } func Command(command string, args ...string) (*StarCmd, error) { - var err error - shell := new(StarCmd) - shell.running = 0 - shell.prewritetime = time.Millisecond * 200 - shell.stdoutBuf = bytes.NewBuffer(make([]byte, 0)) - shell.stderrBuf = bytes.NewBuffer(make([]byte, 0)) - shell.stopctx, shell.stopctxfunc = context.WithCancel(context.Background()) - cmd := exec.Command(command, args...) - shell.CMD = cmd - shell.infile, err = shell.CMD.StdinPipe() - if err != nil { - return shell, err - } - shell.errfile, err = shell.CMD.StderrPipe() - if err != nil { - return shell, err - } - shell.outfile, err = shell.CMD.StdoutPipe() - if err != nil { - return shell, err - } - shell.runerr = nil - shell.exitcode = -999 - return shell, nil + return newStarCmd(exec.Command(command, args...)) } + func CommandContext(ctx context.Context, command string, args ...string) (*StarCmd, error) { + return newStarCmd(exec.CommandContext(ctx, command, args...)) +} + +func newStarCmd(cmd *exec.Cmd) (*StarCmd, error) { var err error - shell := new(StarCmd) - shell.running = 0 - shell.stdoutBuf = bytes.NewBuffer(make([]byte, 0)) - shell.stderrBuf = bytes.NewBuffer(make([]byte, 0)) - shell.prewritetime = time.Millisecond * 200 + shell := &StarCmd{ + CMD: cmd, + prewritetime: time.Millisecond * 200, + stdoutBuf: bytes.NewBuffer(make([]byte, 0)), + stderrBuf: bytes.NewBuffer(make([]byte, 0)), + done: make(chan struct{}), + resultDone: make(chan struct{}), + exitcode: starCmdUnknownExitCode, + } shell.stopctx, shell.stopctxfunc = context.WithCancel(context.Background()) - cmd := exec.CommandContext(ctx, command, args...) - shell.CMD = cmd shell.infile, err = shell.CMD.StdinPipe() if err != nil { return shell, err } - shell.errfile, err = shell.CMD.StderrPipe() - if err != nil { - return shell, err - } - shell.outfile, err = shell.CMD.StdoutPipe() - if err != nil { - return shell, err - } - shell.runerr = nil - shell.exitcode = -999 + shell.CMD.Stdout = starCmdWriter{cmd: shell, stream: starCmdStdout} + shell.CMD.Stderr = starCmdWriter{cmd: shell, stream: starCmdStderr} return shell, nil } -func (starcli *StarCmd) queryStdout(ctx context.Context) { - for starcli.IsRunning() && starcli.CMD != nil { - select { - case <-ctx.Done(): - return - default: - } - out := make([]byte, 65535) - n, err := starcli.outfile.Read(out) - if n != 0 { - starcli.lock.Lock() - starcli.stdoutBuf.Write(out[:n]) - starcli.lock.Unlock() - for _, v := range out[:n] { - starcli.stdout = append(starcli.stdout, v) - } - } - if err != nil { - if err == io.EOF { - break - } else { - if !strings.Contains(err.Error(), "file already closed") { - starcli.runerr = err - } - return - } - } - } -} - -func (starcli *StarCmd) queryStderr(ctx context.Context) { - for starcli.IsRunning() && starcli.CMD != nil { - select { - case <-ctx.Done(): - return - default: - } - out := make([]byte, 65535) - n, err := starcli.errfile.Read(out) - if n != 0 { - starcli.lock.Lock() - starcli.stderrBuf.Write(out[:n]) - starcli.lock.Unlock() - for _, v := range out[:n] { - starcli.errout = append(starcli.errout, v) - } - } - if err != nil { - if err == io.EOF { - break - } else { - if !strings.Contains(err.Error(), "file already closed") { - starcli.runerr = err - } - return - } - } - } - return -} func (starcli *StarCmd) NowLineOutput() (string, error) { + if starcli == nil { + return "", errNilCommand + } starcli.lock.Lock() + defer starcli.lock.Unlock() + starcli.ensureBuffers() buf, _ := starcli.stdoutBuf.ReadBytes('\n') buferr, _ := starcli.stderrBuf.ReadBytes(byte('\n')) - starcli.lock.Unlock() if len(buferr) != 0 { return string(buf), errors.New(string(buferr)) } @@ -164,15 +342,23 @@ func (starcli *StarCmd) NowLineOutput() (string, error) { } func (starcli *StarCmd) NowLineStdOut() string { + if starcli == nil { + return "" + } starcli.lock.Lock() defer starcli.lock.Unlock() + starcli.ensureBuffers() buf, _ := starcli.stdoutBuf.ReadBytes('\n') return string(buf) } func (starcli *StarCmd) NowLineStdErr() error { + if starcli == nil { + return errNilCommand + } starcli.lock.Lock() defer starcli.lock.Unlock() + starcli.ensureBuffers() buferr, _ := starcli.stderrBuf.ReadBytes(byte('\n')) if len(buferr) != 0 { return errors.New(string(buferr)) @@ -181,21 +367,24 @@ func (starcli *StarCmd) NowLineStdErr() error { } func (starcli *StarCmd) NowAllOutput() (string, error) { + if starcli == nil { + return "", errNilCommand + } var outstr string starcli.lock.Lock() + defer starcli.lock.Unlock() + starcli.ensureBuffers() buf := make([]byte, starcli.stdoutBuf.Len()) n, _ := starcli.stdoutBuf.Read(buf) - starcli.lock.Unlock() + runerr := starcli.runerr if n != 0 { outstr = string(buf[:n]) } - if starcli.runerr != nil { - return outstr, starcli.runerr + if runerr != nil { + return outstr, runerr } - starcli.lock.Lock() buf = make([]byte, starcli.stderrBuf.Len()) n, _ = starcli.stderrBuf.Read(buf) - starcli.lock.Unlock() if n != 0 { return outstr, errors.New(string(buf[:n])) } @@ -203,11 +392,15 @@ func (starcli *StarCmd) NowAllOutput() (string, error) { } func (starcli *StarCmd) NowStdOut() string { + if starcli == nil { + return "" + } var outstr string starcli.lock.Lock() + defer starcli.lock.Unlock() + starcli.ensureBuffers() buf := make([]byte, starcli.stdoutBuf.Len()) n, _ := starcli.stdoutBuf.Read(buf) - starcli.lock.Unlock() if n != 0 { outstr = string(buf[:n]) } @@ -215,10 +408,14 @@ func (starcli *StarCmd) NowStdOut() string { } func (starcli *StarCmd) NowStdErr() error { + if starcli == nil { + return errNilCommand + } starcli.lock.Lock() + defer starcli.lock.Unlock() + starcli.ensureBuffers() buf := make([]byte, starcli.stderrBuf.Len()) n, _ := starcli.stderrBuf.Read(buf) - starcli.lock.Unlock() if n != 0 { return errors.New(string(buf[:n])) } @@ -226,6 +423,11 @@ func (starcli *StarCmd) NowStdErr() error { } func (starcli *StarCmd) AllOutPut() (string, error) { + if starcli == nil { + return "", errNilCommand + } + starcli.lock.Lock() + defer starcli.lock.Unlock() err := starcli.runerr if err == nil && len(starcli.errout) != 0 { err = errors.New(string(starcli.errout)) @@ -234,10 +436,20 @@ func (starcli *StarCmd) AllOutPut() (string, error) { } func (starcli *StarCmd) AllStdOut() string { + if starcli == nil { + return "" + } + starcli.lock.Lock() + defer starcli.lock.Unlock() return string(starcli.stdout) } func (starcli *StarCmd) AllStdErr() error { + if starcli == nil { + return errNilCommand + } + starcli.lock.Lock() + defer starcli.lock.Unlock() err := starcli.runerr if err == nil && len(starcli.errout) != 0 { err = errors.New(string(starcli.errout)) @@ -246,102 +458,471 @@ func (starcli *StarCmd) AllStdErr() error { } func (starcli *StarCmd) setRunning(alive bool) { - if alive { - val := atomic.LoadInt32(&starcli.running) - if val == 0 { - atomic.AddInt32(&starcli.running, 1) - } else { - atomic.AddInt32(&starcli.running, 1-val) - } + if starcli == nil { return } - val := atomic.LoadInt32(&starcli.running) - if val == 1 { - atomic.AddInt32(&starcli.running, -1) - } else { - atomic.AddInt32(&starcli.running, -val) + if alive { + atomic.StoreInt32(&starcli.running, 1) + return } + atomic.StoreInt32(&starcli.running, 0) } func (starcli *StarCmd) Start() error { + if starcli == nil || starcli.CMD == nil { + return errNilCommand + } + if atomic.LoadInt32(&starcli.detached) != 0 { + return errCommandDetached + } + starcli.lock.Lock() + starcli.ensureBuffers() + starcli.ensureStopContext() + starcli.lock.Unlock() + if !atomic.CompareAndSwapInt32(&starcli.started, 0, 1) { + return errCommandAlreadyStarted + } if err := starcli.CMD.Start(); err != nil { + starcli.lock.Lock() + starcli.runerr = err + starcli.exitcode = -1 + starcli.lock.Unlock() + starcli.signalResultDone() + starcli.finish() return err } starcli.setRunning(true) go func() { err := starcli.CMD.Wait() if err != nil { + starcli.lock.Lock() starcli.runerr = err + starcli.lock.Unlock() } - starcli.stopctxfunc() - starcli.setRunning(false) if starcli.CMD.ProcessState != nil { - starcli.exitcode = starcli.CMD.ProcessState.Sys().(syscall.WaitStatus).ExitStatus() + starcli.lock.Lock() + starcli.exitcode = starcli.CMD.ProcessState.ExitCode() + starcli.lock.Unlock() } + starcli.signalResultDone() + starcli.finish() }() - go starcli.queryStdout(starcli.stopctx) - go starcli.queryStderr(starcli.stopctx) go func(ctx context.Context) { - if len(starcli.prewrite) != 0 { - for _, v := range starcli.prewrite { - select { - case <-ctx.Done(): - return - default: - break - } - starcli.WriteCmd(v) - time.Sleep(starcli.prewritetime) + starcli.lock.Lock() + prewrite := append([]string(nil), starcli.prewrite...) + prewritetime := starcli.prewritetime + starcli.lock.Unlock() + for _, v := range prewrite { + select { + case <-ctx.Done(): + return + default: } + _ = starcli.WriteCmdE(v) + time.Sleep(prewritetime) } }(starcli.stopctx) return nil } func (starcli *StarCmd) IsRunning() bool { + if starcli == nil { + return false + } return 0 != atomic.LoadInt32(&starcli.running) } +func (starcli *StarCmd) runError() error { + if starcli == nil { + return errNilCommand + } + starcli.lock.Lock() + defer starcli.lock.Unlock() + return starcli.runerr +} + +func (starcli *StarCmd) ensureWaitable() error { + if starcli == nil || starcli.CMD == nil { + return errNilCommand + } + if atomic.LoadInt32(&starcli.started) == 0 { + return errCommandProcessNotStarted + } + return nil +} + +// Stopped returns a channel that is closed after the command reaches its final state. +func (starcli *StarCmd) Stopped() <-chan struct{} { + if starcli == nil { + closed := make(chan struct{}) + close(closed) + return closed + } + starcli.lock.Lock() + starcli.ensureStopContext() + done := starcli.done + starcli.lock.Unlock() + return done +} + +// Stoped returns a channel that is closed after the command reaches its final state. +// +// Deprecated: use Stopped. func (starcli *StarCmd) Stoped() <-chan struct{} { - return starcli.stopctx.Done() + return starcli.Stopped() +} + +// Wait blocks until the command reaches its final state and returns the process wait error. +func (starcli *StarCmd) Wait() error { + if err := starcli.ensureWaitable(); err != nil { + return err + } + <-starcli.ensureResultDone() + return starcli.runError() +} + +// WaitContext blocks until the command reaches its final state or ctx is done. +func (starcli *StarCmd) WaitContext(ctx context.Context) error { + if err := starcli.ensureWaitable(); err != nil { + return err + } + if ctx == nil { + return starcli.Wait() + } + resultDone := starcli.ensureResultDone() + select { + case <-resultDone: + return starcli.runError() + default: + } + select { + case <-resultDone: + return starcli.runError() + case <-ctx.Done(): + select { + case <-resultDone: + return starcli.runError() + default: + } + return ctx.Err() + } +} + +// WaitTimeout blocks until the command reaches its final state or tm elapses. +func (starcli *StarCmd) WaitTimeout(tm time.Duration) error { + if err := starcli.ensureWaitable(); err != nil { + return err + } + if tm <= 0 { + select { + case <-starcli.ensureResultDone(): + return starcli.runError() + default: + return ERR_TIMEOUT + } + } + timer := time.NewTimer(tm) + defer timer.Stop() + resultDone := starcli.ensureResultDone() + select { + case <-resultDone: + return starcli.runError() + case <-timer.C: + select { + case <-resultDone: + return starcli.runError() + default: + return ERR_TIMEOUT + } + } +} + +// StdoutChan returns a channel that receives future stdout chunks until Stopped closes. +func (starcli *StarCmd) StdoutChan() <-chan []byte { + return starcli.registerByteStream(starCmdStdout) +} + +// StderrChan returns a channel that receives future stderr chunks until Stopped closes. +func (starcli *StarCmd) StderrChan() <-chan []byte { + return starcli.registerByteStream(starCmdStderr) +} + +// OutputChan returns a channel that receives future stdout and stderr chunks until Stopped closes. +func (starcli *StarCmd) OutputChan() <-chan StarCmdOutput { + stream := make(chan StarCmdOutput, starCmdStreamBuffer) + if starcli == nil { + close(stream) + return stream + } + starcli.lock.Lock() + if starcli.streamClosed { + close(stream) + } else { + starcli.outputStream = append(starcli.outputStream, stream) + } + starcli.lock.Unlock() + return stream +} + +// RedirectStdout mirrors stdout into writer while keeping StarCmd output capture enabled. +func (starcli *StarCmd) RedirectStdout(writer io.Writer) error { + if writer == nil { + return errCommandRedirectNil + } + if err := starcli.ensureConfigurable(); err != nil { + return err + } + starcli.lock.Lock() + starcli.stdoutRedirect = writer + starcli.lock.Unlock() + return nil +} + +// RedirectStderr mirrors stderr into writer while keeping StarCmd error capture enabled. +func (starcli *StarCmd) RedirectStderr(writer io.Writer) error { + if writer == nil { + return errCommandRedirectNil + } + if err := starcli.ensureConfigurable(); err != nil { + return err + } + starcli.lock.Lock() + starcli.stderrRedirect = writer + starcli.lock.Unlock() + return nil +} + +// RedirectOutput mirrors stdout and stderr into writer while keeping StarCmd capture enabled. +func (starcli *StarCmd) RedirectOutput(writer io.Writer) error { + if writer == nil { + return errCommandRedirectNil + } + if err := starcli.ensureConfigurable(); err != nil { + return err + } + starcli.lock.Lock() + starcli.stdoutRedirect = writer + starcli.stderrRedirect = writer + starcli.lock.Unlock() + return nil +} + +// RedirectStdin replaces the managed stdin pipe with reader. +func (starcli *StarCmd) RedirectStdin(reader io.Reader) error { + if reader == nil { + return errCommandRedirectNil + } + if err := starcli.ensureConfigurable(); err != nil { + return err + } + starcli.lock.Lock() + if starcli.infile != nil && !starcli.inclosed { + if err := starcli.infile.Close(); err != nil { + starcli.lock.Unlock() + return err + } + } + starcli.CMD.Stdin = reader + starcli.infile = nil + starcli.inclosed = true + starcli.lock.Unlock() + return nil +} + +func (starcli *StarCmd) addCloseAfter(closer io.Closer) { + if starcli == nil || closer == nil { + return + } + starcli.lock.Lock() + starcli.closeAfter = append(starcli.closeAfter, closer) + starcli.lock.Unlock() +} + +// RedirectStdoutFile mirrors stdout into path while keeping StarCmd output capture enabled. +func (starcli *StarCmd) RedirectStdoutFile(path string) error { + if err := starcli.ensureConfigurable(); err != nil { + return err + } + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return err + } + if err := starcli.RedirectStdout(file); err != nil { + _ = file.Close() + return err + } + starcli.addCloseAfter(file) + return nil +} + +// RedirectStderrFile mirrors stderr into path while keeping StarCmd error capture enabled. +func (starcli *StarCmd) RedirectStderrFile(path string) error { + if err := starcli.ensureConfigurable(); err != nil { + return err + } + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return err + } + if err := starcli.RedirectStderr(file); err != nil { + _ = file.Close() + return err + } + starcli.addCloseAfter(file) + return nil +} + +// RedirectOutputFile mirrors stdout and stderr into path while keeping StarCmd capture enabled. +func (starcli *StarCmd) RedirectOutputFile(path string) error { + if err := starcli.ensureConfigurable(); err != nil { + return err + } + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return err + } + if err := starcli.RedirectOutput(file); err != nil { + _ = file.Close() + return err + } + starcli.addCloseAfter(file) + return nil +} + +// RedirectStdinFile replaces the managed stdin pipe with path opened for reading. +func (starcli *StarCmd) RedirectStdinFile(path string) error { + if err := starcli.ensureConfigurable(); err != nil { + return err + } + file, err := os.Open(path) + if err != nil { + return err + } + if err := starcli.RedirectStdin(file); err != nil { + _ = file.Close() + return err + } + starcli.addCloseAfter(file) + return nil } func (starcli *StarCmd) Exec(cmd string, wait int) (string, error) { - starcli.infile.Write([]byte(cmd + "\n")) + if err := starcli.WriteCmdE(cmd); err != nil { + return "", err + } time.Sleep(time.Millisecond * time.Duration(wait)) return starcli.NowAllOutput() } func (starcli *StarCmd) WriteCmd(cmdstr string) { - starcli.infile.Write([]byte(cmdstr + "\n")) + _ = starcli.WriteCmdE(cmdstr) +} + +// WriteStdinE writes raw bytes to stdin without appending a newline. +func (starcli *StarCmd) WriteStdinE(data []byte) error { + if starcli == nil { + return errNilCommand + } + starcli.lock.Lock() + infile := starcli.infile + inclosed := starcli.inclosed + starcli.lock.Unlock() + if infile == nil { + return errCommandStdinUnavailable + } + if inclosed { + return errCommandStdinClosed + } + _, err := infile.Write(data) + return err +} + +// WriteStdinStringE writes raw text to stdin without appending a newline. +func (starcli *StarCmd) WriteStdinStringE(data string) error { + return starcli.WriteStdinE([]byte(data)) +} + +// WriteStdinLineE writes text to stdin and appends one newline. +func (starcli *StarCmd) WriteStdinLineE(data string) error { + return starcli.WriteStdinStringE(data + "\n") +} + +func (starcli *StarCmd) WriteCmdE(cmdstr string) error { + return starcli.WriteStdinLineE(cmdstr) +} + +func (starcli *StarCmd) CloseStdin() { + _ = starcli.CloseStdinE() +} + +func (starcli *StarCmd) CloseStdinE() error { + if starcli == nil { + return errNilCommand + } + starcli.lock.Lock() + infile := starcli.infile + if infile == nil { + starcli.lock.Unlock() + return errCommandStdinUnavailable + } + if starcli.inclosed { + starcli.lock.Unlock() + return errCommandStdinClosed + } + starcli.inclosed = true + starcli.lock.Unlock() + return infile.Close() } func (starcli *StarCmd) PreWrite(cmd ...string) { + if starcli == nil { + return + } + starcli.lock.Lock() + defer starcli.lock.Unlock() for _, v := range cmd { starcli.prewrite = append(starcli.prewrite, v) } } func (starcli *StarCmd) PreWriteInterval(dt time.Duration) { + if starcli == nil { + return + } + starcli.lock.Lock() + defer starcli.lock.Unlock() starcli.prewritetime = dt } func (starcli *StarCmd) ExitCode() int { + if starcli == nil { + return starCmdUnknownExitCode + } + starcli.lock.Lock() + defer starcli.lock.Unlock() return starcli.exitcode } func (starcli *StarCmd) Kill() error { + if starcli == nil || starcli.CMD == nil || starcli.CMD.Process == nil { + return errCommandProcessNotStarted + } err := starcli.CMD.Process.Kill() if err != nil { return err } - starcli.setRunning(false) return nil } func (starcli *StarCmd) GetPid() int { + if starcli == nil || starcli.CMD == nil || starcli.CMD.Process == nil { + return -1 + } return starcli.CMD.Process.Pid } func (starcli *StarCmd) Signal(sig os.Signal) error { + if starcli == nil || starcli.CMD == nil || starcli.CMD.Process == nil { + return errCommandProcessNotStarted + } return starcli.CMD.Process.Signal(sig) } diff --git a/process_darwin.go b/process_darwin.go new file mode 100644 index 0000000..e4f7e49 --- /dev/null +++ b/process_darwin.go @@ -0,0 +1,73 @@ +//go:build darwin +// +build darwin + +package staros + +import "sync/atomic" + +func FindProcessByName(name string) (datas []Process, err error) { + return nil, ERR_UNSUPPORTED +} + +func FindProcess(compare func(Process) bool) (datas []Process, err error) { + return nil, ERR_UNSUPPORTED +} + +func FindProcessByPid(pid int64) (datas Process, err error) { + return datas, ERR_UNSUPPORTED +} + +func Daemon(path string, args ...string) (int, error) { + return -1, ERR_UNSUPPORTED +} + +func DaemonWithUser(uid, gid uint32, groups []uint32, path string, args ...string) (int, error) { + return -1, ERR_UNSUPPORTED +} + +func (starcli *StarCmd) SetRunUser(uid, gid uint32, groups []uint32) { + _ = starcli.SetRunUserE(uid, gid, groups) +} + +func (starcli *StarCmd) SetRunUserE(uid, gid uint32, groups []uint32) error { + if starcli == nil || starcli.CMD == nil { + return errNilCommand + } + if atomic.LoadInt32(&starcli.started) != 0 { + return errCommandAlreadyStarted + } + return ERR_UNSUPPORTED +} + +func (starcli *StarCmd) Release() error { + return starcli.ReleaseE() +} + +func (starcli *StarCmd) Detach() error { + return starcli.DetachE() +} + +func (starcli *StarCmd) ReleaseE() error { + if starcli == nil || starcli.CMD == nil { + return errNilCommand + } + return ERR_UNSUPPORTED +} + +func (starcli *StarCmd) DetachE() error { + if starcli == nil || starcli.CMD == nil { + return errNilCommand + } + return ERR_UNSUPPORTED +} + +func (starcli *StarCmd) SetKeepCaps() error { + if starcli == nil || starcli.CMD == nil { + return errNilCommand + } + return ERR_UNSUPPORTED +} + +func SetKeepCaps() error { + return ERR_UNSUPPORTED +} diff --git a/process_linux_test.go b/process_linux_test.go new file mode 100644 index 0000000..66ac4fd --- /dev/null +++ b/process_linux_test.go @@ -0,0 +1,110 @@ +//go:build linux +// +build linux + +package staros + +import ( + "errors" + "reflect" + "syscall" + "testing" +) + +func TestStarCmdSetKeepCapsConfiguresAmbientCaps(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + + original := loadCurrentKeepCaps + loadCurrentKeepCaps = func() ([]uintptr, error) { + return []uintptr{7, 1, 7}, nil + } + t.Cleanup(func() { + loadCurrentKeepCaps = original + }) + + cmd.CMD.SysProcAttr = &syscall.SysProcAttr{ + AmbientCaps: []uintptr{9, 1}, + } + if err := cmd.SetKeepCaps(); err != nil { + t.Fatal(err) + } + + want := []uintptr{1, 7, 9} + if got := cmd.CMD.SysProcAttr.AmbientCaps; !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected ambient caps: got=%v want=%v", got, want) + } +} + +func TestStarCmdSetKeepCapsPropagatesCapabilityReadError(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + + wantErr := errors.New("capget failed") + original := loadCurrentKeepCaps + loadCurrentKeepCaps = func() ([]uintptr, error) { + return nil, wantErr + } + t.Cleanup(func() { + loadCurrentKeepCaps = original + }) + + if err := cmd.SetKeepCaps(); !errors.Is(err, wantErr) { + t.Fatalf("expected keepcaps read error, got %v", err) + } +} + +func TestStarCmdSetRunUserPreservesExistingSysProcAttr(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + + original := loadCurrentKeepCaps + loadCurrentKeepCaps = func() ([]uintptr, error) { + return []uintptr{7, 1, 7}, nil + } + t.Cleanup(func() { + loadCurrentKeepCaps = original + }) + + cmd.CMD.SysProcAttr = &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGTERM, + AmbientCaps: []uintptr{9}, + } + if err := cmd.SetKeepCaps(); err != nil { + t.Fatal(err) + } + + groups := []uint32{3, 4} + if err := cmd.SetRunUserE(1, 2, groups); err != nil { + t.Fatal(err) + } + groups[0] = 99 + + if got, want := cmd.CMD.SysProcAttr.AmbientCaps, []uintptr{1, 7, 9}; !reflect.DeepEqual(got, want) { + t.Fatalf("ambient caps lost after SetRunUserE: got=%v want=%v", got, want) + } + if got := cmd.CMD.SysProcAttr.Pdeathsig; got != syscall.SIGTERM { + t.Fatalf("expected Pdeathsig to be preserved, got %v", got) + } + if !cmd.CMD.SysProcAttr.Setsid { + t.Fatal("expected Setsid to be enabled") + } + cred := cmd.CMD.SysProcAttr.Credential + if cred == nil { + t.Fatal("expected credential to be configured") + } + if cred.Uid != 1 || cred.Gid != 2 { + t.Fatalf("unexpected credential ids: uid=%d gid=%d", cred.Uid, cred.Gid) + } + if got, want := cred.Groups, []uint32{3, 4}; !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected credential groups: got=%v want=%v", got, want) + } +} diff --git a/process_test.go b/process_test.go index 047e722..f84943e 100644 --- a/process_test.go +++ b/process_test.go @@ -1,27 +1,813 @@ package staros import ( + "bytes" "context" - "fmt" + "encoding/base64" + "encoding/binary" + "errors" + "io/ioutil" + "os" + "path/filepath" + "runtime" + "strings" "testing" "time" + "unicode/utf16" ) -func Test_Process(t *testing.T) { - fmt.Println(FindProcessByPid(16652)) -} - -func Test_StarCmd(t *testing.T) { - ctx, _ := context.WithTimeout(context.Background(), time.Second*5) - cmd, _ := CommandContext(ctx, "cmd.exe", "/c", "ping -t 127.0.0.1") - cmd.Start() - for cmd.IsRunning() { - fmt.Print(cmd.NowLineOutput()) - time.Sleep(time.Millisecond * 50) +func testCommandArgs(script string) (string, []string) { + if runtime.GOOS == "windows" { + return "cmd.exe", []string{"/c", script} + } + return "sh", []string{"-c", script} +} + +func testWindowsPowerShellArgs(script string) (string, []string) { + utf16Script := utf16.Encode([]rune(script)) + encoded := make([]byte, len(utf16Script)*2) + for i, r := range utf16Script { + binary.LittleEndian.PutUint16(encoded[i*2:], uint16(r)) + } + return "powershell.exe", []string{"-NoProfile", "-EncodedCommand", base64.StdEncoding.EncodeToString(encoded)} +} + +type closeTrackingWriteCloser struct { + closed bool +} + +func (w *closeTrackingWriteCloser) Write(data []byte) (int, error) { + return len(data), nil +} + +func (w *closeTrackingWriteCloser) Close() error { + w.closed = true + return nil +} + +func TestStarCmdCapturesOutputAndExitCode(t *testing.T) { + script := "printf 'hello'; printf 'err' 1>&2" + command, args := testCommandArgs(script) + if runtime.GOOS == "windows" { + command, args = testWindowsPowerShellArgs("[Console]::Out.Write('hello'); [Console]::Error.Write('err')") + } + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + <-cmd.Stoped() + out, outErr := cmd.AllOutPut() + if out != "hello" { + t.Fatalf("expected stdout %q, got %q", "hello", out) + } + if outErr == nil || outErr.Error() != "err" { + t.Fatalf("expected stderr error %q, got %v", "err", outErr) + } + if got := cmd.ExitCode(); got != 0 { + t.Fatalf("expected exit code 0, got %d", got) + } +} + +func TestStarCmdWaitReturnsProcessError(t *testing.T) { + command, args := testCommandArgs("exit 7") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Wait(); !errors.Is(err, errCommandProcessNotStarted) { + t.Fatalf("expected errCommandProcessNotStarted before start, got %v", err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + if err := cmd.Wait(); err == nil { + t.Fatal("expected wait error for non-zero exit") + } + if got := cmd.ExitCode(); got != 7 { + t.Fatalf("expected exit code 7, got %d", got) + } + if err := cmd.Wait(); err == nil { + t.Fatal("expected repeated Wait to keep final process error") + } +} + +func TestStarCmdWaitTimeoutAndContext(t *testing.T) { + command, args := testCommandArgs("sleep 1") + if runtime.GOOS == "windows" { + command, args = testCommandArgs("ping -n 2 127.0.0.1 >nul") + } + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + if err := cmd.WaitTimeout(10 * time.Millisecond); !errors.Is(err, ERR_TIMEOUT) { + t.Fatalf("expected ERR_TIMEOUT, got %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + if err := cmd.WaitContext(ctx); !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected context deadline, got %v", err) + } + if err := cmd.WaitTimeout(3 * time.Second); err != nil { + t.Fatalf("expected command to finish, got %v", err) + } +} + +func TestStarCmdWaitReturnsResultAfterProcessDone(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + if err := cmd.Wait(); err != nil { + t.Fatal(err) + } + if err := cmd.WaitTimeout(0); err != nil { + t.Fatalf("expected finished command to beat zero timeout, got %v", err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := cmd.WaitContext(ctx); err != nil { + t.Fatalf("expected finished command to beat canceled context, got %v", err) + } +} + +func TestStarCmdWaitContextFinishedCommandWinsOverCanceledContext(t *testing.T) { + t.Run("success", func(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + if err := cmd.Wait(); err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := cmd.WaitContext(ctx); err != nil { + t.Fatalf("finished successful command should win over canceled context, got %v", err) + } + }) + t.Run("failed", func(t *testing.T) { + command, args := testCommandArgs("exit 7") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + waitErr := cmd.Wait() + if waitErr == nil { + t.Fatal("expected command wait error") + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := cmd.WaitContext(ctx); err == nil || err.Error() != waitErr.Error() { + t.Fatalf("finished failed command should win over canceled context, got %v, want %v", err, waitErr) + } + }) +} + +func TestStarCmdStoppedAlias(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + select { + case <-cmd.Stopped(): + case <-time.After(time.Second): + t.Fatal("Stopped should close after command exits") + } + select { + case <-cmd.Stoped(): + case <-time.After(time.Second): + t.Fatal("Stoped compatibility alias should close after command exits") + } +} + +func TestStarCmdStopedPublishesFinalExitCode(t *testing.T) { + command, args := testCommandArgs("exit 7") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + <-cmd.Stoped() + if cmd.IsRunning() { + t.Fatal("command should not be running after Stoped closes") + } + if got := cmd.ExitCode(); got != 7 { + t.Fatalf("expected exit code 7, got %d", got) + } +} + +func TestStarCmdRejectsRepeatedStart(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + if err := cmd.Start(); !errors.Is(err, errCommandAlreadyStarted) { + t.Fatalf("expected errCommandAlreadyStarted, got %v", err) + } + <-cmd.Stoped() + if err := cmd.Start(); !errors.Is(err, errCommandAlreadyStarted) { + t.Fatalf("expected errCommandAlreadyStarted after exit, got %v", err) + } +} + +func TestStarCmdStartFailureClosesStoped(t *testing.T) { + cmd, err := Command("__staros_missing_command__") + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err == nil { + t.Fatal("expected start failure") + } + select { + case <-cmd.Stoped(): + case <-time.After(time.Second): + t.Fatal("Stoped should close after start failure") + } + if cmd.IsRunning() { + t.Fatal("command should not be running after start failure") + } + if got := cmd.ExitCode(); got != -1 { + t.Fatalf("expected exit code -1 after start failure, got %d", got) + } +} + +func TestStarCmdCapturesLargeOutput(t *testing.T) { + expected := strings.Repeat("x", 256*1024) + script := "awk 'BEGIN{for(i=0;i<262144;i++) printf \"x\"}'" + command, args := testCommandArgs(script) + if runtime.GOOS == "windows" { + command, args = testWindowsPowerShellArgs("[Console]::Out.Write(('x' * 262144))") + } + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + <-cmd.Stoped() + out, err := cmd.AllOutPut() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal([]byte(out), []byte(expected)) { + t.Fatalf("expected %d stdout bytes, got %d", len(expected), len(out)) + } +} + +func TestStarCmdStreamsOutput(t *testing.T) { + script := "printf 'out'; printf 'err' 1>&2" + command, args := testCommandArgs(script) + if runtime.GOOS == "windows" { + command, args = testWindowsPowerShellArgs("[Console]::Out.Write('out'); [Console]::Error.Write('err')") + } + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + stdout := cmd.StdoutChan() + stderr := cmd.StderrChan() + output := cmd.OutputChan() + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + var stdoutData, stderrData string + var outputData []StarCmdOutput + for stdout != nil || stderr != nil || output != nil { + select { + case data, ok := <-stdout: + if !ok { + stdout = nil + continue + } + stdoutData += string(data) + case data, ok := <-stderr: + if !ok { + stderr = nil + continue + } + stderrData += string(data) + case data, ok := <-output: + if !ok { + output = nil + continue + } + outputData = append(outputData, data) + case <-time.After(3 * time.Second): + t.Fatal("stream output timed out") + } + } + if stdoutData != "out" { + t.Fatalf("expected streamed stdout %q, got %q", "out", stdoutData) + } + if stderrData != "err" { + t.Fatalf("expected streamed stderr %q, got %q", "err", stderrData) + } + var seenStdout, seenStderr bool + for _, item := range outputData { + switch item.Stream { + case StarCmdOutputStdout: + seenStdout = seenStdout || string(item.Data) == "out" + case StarCmdOutputStderr: + seenStderr = seenStderr || string(item.Data) == "err" + default: + t.Fatalf("unknown output stream %v", item.Stream) + } + } + if !seenStdout || !seenStderr { + t.Fatalf("expected combined output stream to include stdout and stderr, got %#v", outputData) + } +} + +func TestStarCmdStreamNilReturnsClosedChannels(t *testing.T) { + var cmd *StarCmd + select { + case _, ok := <-cmd.StdoutChan(): + if ok { + t.Fatal("nil stdout stream should be closed") + } + case <-time.After(time.Second): + t.Fatal("nil stdout stream should close immediately") + } + select { + case _, ok := <-cmd.StderrChan(): + if ok { + t.Fatal("nil stderr stream should be closed") + } + case <-time.After(time.Second): + t.Fatal("nil stderr stream should close immediately") + } + select { + case _, ok := <-cmd.OutputChan(): + if ok { + t.Fatal("nil output stream should be closed") + } + case <-time.After(time.Second): + t.Fatal("nil output stream should close immediately") + } +} + +func TestStarCmdRedirectOutputWriterKeepsCapture(t *testing.T) { + script := "printf 'out'; printf 'err' 1>&2" + command, args := testCommandArgs(script) + if runtime.GOOS == "windows" { + command, args = testWindowsPowerShellArgs("[Console]::Out.Write('out'); [Console]::Error.Write('err')") + } + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + var redirected bytes.Buffer + if err := cmd.RedirectOutput(&redirected); err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + <-cmd.Stopped() + if got := redirected.String(); got != "outerr" && got != "errout" { + t.Fatalf("expected redirected stdout/stderr bytes, got %q", got) + } + if out := cmd.AllStdOut(); out != "out" { + t.Fatalf("expected captured stdout %q, got %q", "out", out) + } + if err := cmd.AllStdErr(); err == nil || err.Error() != "err" { + t.Fatalf("expected captured stderr %q, got %v", "err", err) + } +} + +func TestStarCmdRedirectFiles(t *testing.T) { + dir := t.TempDir() + stdoutFile := filepath.Join(dir, "stdout.txt") + stderrFile := filepath.Join(dir, "stderr.txt") + script := "printf 'out'; printf 'err' 1>&2" + command, args := testCommandArgs(script) + if runtime.GOOS == "windows" { + command, args = testWindowsPowerShellArgs("[Console]::Out.Write('out'); [Console]::Error.Write('err')") + } + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.RedirectStdoutFile(stdoutFile); err != nil { + t.Fatal(err) + } + if err := cmd.RedirectStderrFile(stderrFile); err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + <-cmd.Stopped() + stdoutData, err := ioutil.ReadFile(stdoutFile) + if err != nil { + t.Fatal(err) + } + stderrData, err := ioutil.ReadFile(stderrFile) + if err != nil { + t.Fatal(err) + } + if string(stdoutData) != "out" { + t.Fatalf("expected stdout file %q, got %q", "out", string(stdoutData)) + } + if string(stderrData) != "err" { + t.Fatalf("expected stderr file %q, got %q", "err", string(stderrData)) + } +} + +func TestStarCmdRedirectStdin(t *testing.T) { + command, args := testCommandArgs("cat") + if runtime.GOOS == "windows" { + command, args = testCommandArgs("more") + } + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.RedirectStdin(strings.NewReader("hello\n")); err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + <-cmd.Stopped() + if out := cmd.AllStdOut(); !strings.Contains(out, "hello") { + t.Fatalf("expected redirected stdin in stdout, got %q", out) + } + if err := cmd.AllStdErr(); err != nil { + t.Fatalf("redirected stdin should not create command error, got %v", err) + } + if err := cmd.WriteCmdE("again"); !errors.Is(err, errCommandStdinUnavailable) { + t.Fatalf("expected errCommandStdinUnavailable after stdin redirect, got %v", err) + } +} + +func TestStarCmdRedirectStdinClosesManagedPipe(t *testing.T) { + command, args := testCommandArgs("cat") + if runtime.GOOS == "windows" { + command, args = testCommandArgs("more") + } + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + tracker := &closeTrackingWriteCloser{} + cmd.lock.Lock() + cmd.infile = tracker + cmd.inclosed = false + cmd.lock.Unlock() + + if err := cmd.RedirectStdin(strings.NewReader("hello\n")); err != nil { + t.Fatal(err) + } + if !tracker.closed { + t.Fatal("RedirectStdin should close the previously managed stdin pipe") + } + if err := cmd.WriteCmdE("again"); !errors.Is(err, errCommandStdinUnavailable) { + t.Fatalf("expected managed stdin to be unavailable after redirect, got %v", err) + } +} + +func TestStarCmdDetachClosesManagedPipe(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + tracker := &closeTrackingWriteCloser{} + cmd.lock.Lock() + original := cmd.infile + cmd.infile = tracker + cmd.inclosed = false + cmd.lock.Unlock() + if original != nil { + _ = original.Close() + } + + if err := cmd.DetachE(); errors.Is(err, ERR_UNSUPPORTED) { + t.Skip(err) + } else if err != nil { + t.Fatal(err) + } + if !tracker.closed { + t.Fatal("DetachE should close the managed stdin pipe") + } + if err := cmd.WriteCmdE("again"); !errors.Is(err, errCommandStdinClosed) { + t.Fatalf("expected detached stdin to be closed, got %v", err) + } +} + +func TestStarCmdRedirectRejectsInvalidState(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.RedirectStdout(nil); !errors.Is(err, errCommandRedirectNil) { + t.Fatalf("expected errCommandRedirectNil, got %v", err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + if err := cmd.RedirectStdout(&bytes.Buffer{}); !errors.Is(err, errCommandAlreadyStarted) { + t.Fatalf("expected errCommandAlreadyStarted, got %v", err) + } + <-cmd.Stopped() +} + +func TestStarCmdCloseStdinLetsCommandExit(t *testing.T) { + script := "cat" + if runtime.GOOS == "windows" { + script = "more" + } + command, args := testCommandArgs(script) + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + if err := cmd.WriteCmdE("hello"); err != nil { + t.Fatal(err) + } + if err := cmd.CloseStdinE(); err != nil { + t.Fatal(err) + } + select { + case <-cmd.Stoped(): + case <-time.After(3 * time.Second): + t.Fatal("command should exit after stdin closes") + } + if out := cmd.AllStdOut(); !strings.Contains(out, "hello") { + t.Fatalf("expected echoed stdin, got %q", out) + } + if err := cmd.CloseStdinE(); !errors.Is(err, errCommandStdinClosed) { + t.Fatalf("expected errCommandStdinClosed, got %v", err) + } + if err := cmd.WriteCmdE("again"); !errors.Is(err, errCommandStdinClosed) { + t.Fatalf("expected errCommandStdinClosed after close, got %v", err) + } +} + +func TestStarCmdWriteStdinRawDoesNotAppendNewline(t *testing.T) { + script := "cat" + if runtime.GOOS == "windows" { + script = "more" + } + command, args := testCommandArgs(script) + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + if err := cmd.WriteStdinStringE("raw"); err != nil { + t.Fatal(err) + } + if err := cmd.WriteStdinE([]byte("-bytes")); err != nil { + t.Fatal(err) + } + if err := cmd.CloseStdinE(); err != nil { + t.Fatal(err) + } + if err := cmd.WaitTimeout(3 * time.Second); err != nil { + t.Fatal(err) + } + if out := cmd.AllStdOut(); !strings.Contains(out, "raw-bytes") { + t.Fatalf("expected raw stdin without inserted newline, got %q", out) + } +} + +func TestStarCmdNilGuards(t *testing.T) { + var cmd *StarCmd + if cmd.IsRunning() { + t.Fatal("nil StarCmd should not be running") + } + if got := cmd.GetPid(); got != -1 { + t.Fatalf("expected nil pid -1, got %d", got) + } + if err := cmd.Release(); !errors.Is(err, errNilCommand) { + t.Fatalf("expected errNilCommand, got %v", err) + } + if err := cmd.SetKeepCaps(); !errors.Is(err, errNilCommand) { + t.Fatalf("expected errNilCommand, got %v", err) + } + if err := cmd.ReleaseE(); !errors.Is(err, errNilCommand) { + t.Fatalf("expected errNilCommand, got %v", err) + } + if err := cmd.DetachE(); !errors.Is(err, errNilCommand) { + t.Fatalf("expected errNilCommand, got %v", err) + } + if err := cmd.SetRunUserE(0, 0, nil); !errors.Is(err, errNilCommand) { + t.Fatalf("expected errNilCommand, got %v", err) + } + if err := cmd.WriteCmdE("noop"); !errors.Is(err, errNilCommand) { + t.Fatalf("expected errNilCommand, got %v", err) + } + if err := cmd.WriteStdinE([]byte("noop")); !errors.Is(err, errNilCommand) { + t.Fatalf("expected errNilCommand, got %v", err) + } + if err := cmd.WriteStdinStringE("noop"); !errors.Is(err, errNilCommand) { + t.Fatalf("expected errNilCommand, got %v", err) + } + if err := cmd.WriteStdinLineE("noop"); !errors.Is(err, errNilCommand) { + t.Fatalf("expected errNilCommand, got %v", err) + } + if err := cmd.Wait(); !errors.Is(err, errNilCommand) { + t.Fatalf("expected errNilCommand, got %v", err) + } + if err := cmd.CloseStdinE(); !errors.Is(err, errNilCommand) { + t.Fatalf("expected errNilCommand, got %v", err) + } +} + +func TestStarCmdReleaseUsesStartLifecycle(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.ReleaseE(); errors.Is(err, ERR_UNSUPPORTED) { + t.Skip(err) + } else if err != nil { + t.Fatal(err) + } + <-cmd.Stoped() + if got := cmd.ExitCode(); got != 0 { + t.Fatalf("expected exit code 0, got %d", got) + } + if err := cmd.ReleaseE(); !errors.Is(err, errCommandAlreadyReleased) { + t.Fatalf("expected errCommandAlreadyReleased, got %v", err) + } +} + +func TestStarCmdReleaseAfterStartKeepsLifecycle(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + if err := cmd.ReleaseE(); errors.Is(err, ERR_UNSUPPORTED) { + t.Skip(err) + } else if err != nil { + t.Fatal(err) + } + <-cmd.Stoped() + if got := cmd.ExitCode(); got != 0 { + t.Fatalf("expected exit code 0, got %d", got) + } +} + +func TestStarCmdDetachRejectsRepeatedDetach(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.DetachE(); errors.Is(err, ERR_UNSUPPORTED) { + t.Skip(err) + } else if err != nil { + t.Fatal(err) + } + if err := cmd.DetachE(); !errors.Is(err, errCommandAlreadyDetached) { + t.Fatalf("expected errCommandAlreadyDetached, got %v", err) + } + if err := cmd.Start(); !errors.Is(err, errCommandDetached) { + t.Fatalf("expected errCommandDetached, got %v", err) + } +} + +func TestStarCmdDetachPublishesWaitResult(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.DetachE(); errors.Is(err, ERR_UNSUPPORTED) { + t.Skip(err) + } else if err != nil { + t.Fatal(err) + } + if err := cmd.WaitTimeout(0); err != nil { + t.Fatalf("detached command should publish final wait result, got %v", err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := cmd.WaitContext(ctx); err != nil { + t.Fatalf("detached command should beat canceled context, got %v", err) + } + waitErr := make(chan error, 1) + go func() { + waitErr <- cmd.Wait() + }() + select { + case err := <-waitErr: + if err != nil { + t.Fatalf("detached command wait got %v", err) + } + case <-time.After(time.Second): + t.Fatal("detached command wait did not observe final result") + } +} + +func TestStarCmdDetachDoesNotCaptureOutput(t *testing.T) { + script := "printf 'detached'; printf 'err' 1>&2" + if runtime.GOOS == "windows" { + script = "&2" + } + command, args := testCommandArgs(script) + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.DetachE(); errors.Is(err, ERR_UNSUPPORTED) { + t.Skip(err) + } else if err != nil { + t.Fatal(err) + } + <-cmd.Stoped() + if out := cmd.AllStdOut(); out != "" { + t.Fatalf("detached command should not be captured, got stdout %q", out) + } + if err := cmd.AllStdErr(); err != nil { + t.Fatalf("detached command should not capture stderr, got %v", err) + } +} + +func TestStarCmdDetachRejectsStartedCommand(t *testing.T) { + command, args := testCommandArgs("exit 0") + cmd, err := Command(command, args...) + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + if err := cmd.DetachE(); errors.Is(err, ERR_UNSUPPORTED) { + t.Skip(err) + } else if !errors.Is(err, errCommandAlreadyStarted) { + t.Fatalf("expected errCommandAlreadyStarted, got %v", err) + } + <-cmd.Stoped() +} + +func TestFindProcessByPidCurrentProcess(t *testing.T) { + pid := os.Getpid() + process, err := FindProcessByPid(int64(pid)) + if errors.Is(err, ERR_UNSUPPORTED) { + t.Skip(err) + } + if err != nil { + t.Fatal(err) + } + if process.Pid != int64(pid) { + t.Fatalf("expected pid %d, got %d", pid, process.Pid) + } +} + +func TestStopedNilReturnsClosedChannel(t *testing.T) { + var cmd *StarCmd + select { + case <-cmd.Stoped(): + case <-time.After(time.Second): + t.Fatal("nil Stoped channel should already be closed") + } + select { + case <-cmd.Stopped(): + case <-time.After(time.Second): + t.Fatal("nil Stopped channel should already be closed") } - fmt.Println(cmd.NowAllOutput()) - fmt.Print("all is ") - fmt.Println(cmd.AllOutPut()) - fmt.Println(cmd.ExitCode()) - } diff --git a/process_unix.go b/process_unix.go index 6683f39..580fa42 100644 --- a/process_unix.go +++ b/process_unix.go @@ -1,4 +1,5 @@ -// +build linux darwin +//go:build linux +// +build linux package staros @@ -10,13 +11,19 @@ import ( "os" "os/exec" "path/filepath" + "sort" "strconv" "strings" + "sync/atomic" "syscall" "time" + + "golang.org/x/sys/unix" ) -//FindProcessByName 通过进程名来查询应用信息 +var loadCurrentKeepCaps = currentKeepCaps + +// FindProcessByName 通过进程名来查询应用信息 func FindProcessByName(name string) (datas []Process, err error) { return FindProcess(func(in Process) bool { if name == in.Name { @@ -30,43 +37,11 @@ func FindProcessByName(name string) (datas []Process, err error) { func FindProcess(compare func(Process) bool) (datas []Process, err error) { var name, main string var mainb []byte - var netErr error - var netInfo []NetConn + netSnapshot := loadNetSnapshot(false) paths, err := ioutil.ReadDir("/proc") if err != nil { return } - netInfo, netErr = NetConnections(false, "") - appendNetInfo := func(p *Process) { - if netErr != nil { - p.netErr = netErr - return - } - fds, err := ioutil.ReadDir("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd") - if err != nil && Exists("/proc/"+strconv.Itoa(int(p.Pid))+"/fd") { - p.netErr = err - return - } - for _, fd := range fds { - socket, err := os.Readlink("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd/" + fd.Name()) - if err != nil { - p.netErr = err - return - } - start := strings.Index(socket, "[") - if start < 0 { - continue - } - sid := socket[start+1 : len(socket)-1] - for _, v := range netInfo { - if v.Inode == sid { - v.Pid = p.Pid - v.Process = p - p.netConn = append(p.netConn, v) - } - } - } - } for _, v := range paths { if v.IsDir() && Exists("/proc/"+v.Name()+"/comm") { name, err = readAsString("/proc/" + v.Name() + "/comm") @@ -83,7 +58,7 @@ func FindProcess(compare func(Process) bool) (datas []Process, err error) { if err != nil { tmp.Err = err if compare(tmp) { - appendNetInfo(&tmp) + netSnapshot.appendTo(&tmp) datas = append(datas, tmp) continue } @@ -94,28 +69,22 @@ func FindProcess(compare func(Process) bool) (datas []Process, err error) { tmp.TPid, _ = strconv.ParseInt(data["TracerPid"], 10, 64) uids := splitBySpace(data["Uid"]) gids := splitBySpace(data["Gid"]) - tmp.RUID, _ = strconv.Atoi(uids[0]) - tmp.EUID, _ = strconv.Atoi(uids[1]) - tmp.RGID, _ = strconv.Atoi(gids[0]) - tmp.EGID, _ = strconv.Atoi(gids[1]) - tmp.VmPeak, _ = strconv.ParseInt(splitBySpace(data["VmPeak"])[0], 10, 64) - tmp.VmSize, _ = strconv.ParseInt(splitBySpace(data["VmSize"])[0], 10, 64) - tmp.VmHWM, _ = strconv.ParseInt(splitBySpace(data["VmHWM"])[0], 10, 64) - tmp.VmRSS, _ = strconv.ParseInt(splitBySpace(data["VmRSS"])[0], 10, 64) - tmp.VmLck, _ = strconv.ParseInt(splitBySpace(data["VmLck"])[0], 10, 64) - tmp.VmData, _ = strconv.ParseInt(splitBySpace(data["VmData"])[0], 10, 64) - tmp.VmLck *= 1024 - tmp.VmData *= 1024 - tmp.VmPeak *= 1024 - tmp.VmSize *= 1024 - tmp.VmHWM *= 1024 - tmp.VmRSS *= 1024 + tmp.RUID, _ = atoiField(uids, 0) + tmp.EUID, _ = atoiField(uids, 1) + tmp.RGID, _ = atoiField(gids, 0) + tmp.EGID, _ = atoiField(gids, 1) + tmp.VmPeak = parseProcStatusKB(data["VmPeak"]) + tmp.VmSize = parseProcStatusKB(data["VmSize"]) + tmp.VmHWM = parseProcStatusKB(data["VmHWM"]) + tmp.VmRSS = parseProcStatusKB(data["VmRSS"]) + tmp.VmLck = parseProcStatusKB(data["VmLck"]) + tmp.VmData = parseProcStatusKB(data["VmData"]) } mainb, err = ioutil.ReadFile("/proc/" + v.Name() + "/cmdline") if err != nil { tmp.Err = err if compare(tmp) { - appendNetInfo(&tmp) + netSnapshot.appendTo(&tmp) datas = append(datas, tmp) continue } @@ -129,7 +98,7 @@ func FindProcess(compare func(Process) bool) (datas []Process, err error) { if err != nil { tmp.Err = err if compare(tmp) { - appendNetInfo(&tmp) + netSnapshot.appendTo(&tmp) datas = append(datas, tmp) continue } @@ -144,17 +113,15 @@ func FindProcess(compare func(Process) bool) (datas []Process, err error) { if err != nil { tmp.Err = err if compare(tmp) { - appendNetInfo(&tmp) + netSnapshot.appendTo(&tmp) datas = append(datas, tmp) continue } - } else { - times := splitBySpace(main) - uptime, _ := strconv.ParseInt(strings.TrimSpace(times[21]), 10, 64) - tmp.Uptime = time.Unix(StartTime().Unix()+uptime/100, int64((float64(uptime)/100-float64(uptime/100))*1000000000)) + } else if uptime, ok := procStartTimeFromStat([]byte(main)); ok { + tmp.Uptime = uptime } if compare(tmp) { - appendNetInfo(&tmp) + netSnapshot.appendTo(&tmp) datas = append(datas, tmp) } } @@ -170,38 +137,6 @@ func FindProcessByPid(pid int64) (datas Process, err error) { err = errors.New("Not Found") return } - netInfo, netErr := NetConnections(false, "") - appendNetInfo := func(p *Process) { - if netErr != nil { - p.netErr = netErr - return - } - fds, err := ioutil.ReadDir("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd") - if err != nil && Exists("/proc/"+strconv.Itoa(int(p.Pid))+"/fd") { - p.netErr = err - return - } - for _, fd := range fds { - socket, err := os.Readlink("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd/" + fd.Name()) - if err != nil { - p.netErr = err - return - } - start := strings.Index(socket, "[") - if start < 0 { - continue - } - sid := socket[start+1 : len(socket)-1] - for _, v := range netInfo { - if v.Inode == sid { - v.Pid = p.Pid - v.Process = p - p.netConn = append(p.netConn, v) - } - } - } - } - name, err = readAsString("/proc/" + fmt.Sprint(pid) + "/comm") if err != nil { return @@ -217,23 +152,17 @@ func FindProcessByPid(pid int64) (datas Process, err error) { datas.TPid, _ = strconv.ParseInt(data["TracerPid"], 10, 64) uids := splitBySpace(data["Uid"]) gids := splitBySpace(data["Gid"]) - datas.RUID, _ = strconv.Atoi(uids[0]) - datas.EUID, _ = strconv.Atoi(uids[1]) - datas.RGID, _ = strconv.Atoi(gids[0]) - datas.EGID, _ = strconv.Atoi(gids[1]) - datas.VmPeak, _ = strconv.ParseInt(splitBySpace(data["VmPeak"])[0], 10, 64) - datas.VmSize, _ = strconv.ParseInt(splitBySpace(data["VmSize"])[0], 10, 64) - datas.VmHWM, _ = strconv.ParseInt(splitBySpace(data["VmHWM"])[0], 10, 64) - datas.VmRSS, _ = strconv.ParseInt(splitBySpace(data["VmRSS"])[0], 10, 64) - datas.VmLck, _ = strconv.ParseInt(splitBySpace(data["VmLck"])[0], 10, 64) - datas.VmData, _ = strconv.ParseInt(splitBySpace(data["VmData"])[0], 10, 64) - datas.VmLck *= 1024 - datas.VmData *= 1024 - datas.VmPeak *= 1024 - datas.VmSize *= 1024 - datas.VmHWM *= 1024 - datas.VmRSS *= 1024 - appendNetInfo(&datas) + datas.RUID, _ = atoiField(uids, 0) + datas.EUID, _ = atoiField(uids, 1) + datas.RGID, _ = atoiField(gids, 0) + datas.EGID, _ = atoiField(gids, 1) + datas.VmPeak = parseProcStatusKB(data["VmPeak"]) + datas.VmSize = parseProcStatusKB(data["VmSize"]) + datas.VmHWM = parseProcStatusKB(data["VmHWM"]) + datas.VmRSS = parseProcStatusKB(data["VmRSS"]) + datas.VmLck = parseProcStatusKB(data["VmLck"]) + datas.VmData = parseProcStatusKB(data["VmData"]) + loadNetSnapshot(false).appendTo(&datas) mainb, err = ioutil.ReadFile("/proc/" + fmt.Sprint(pid) + "/cmdline") if err != nil { datas.Err = err @@ -264,12 +193,92 @@ func FindProcessByPid(pid int64) (datas Process, err error) { if err != nil { return } - times := splitBySpace(main) - uptime, _ := strconv.ParseInt(strings.TrimSpace(times[21]), 10, 64) - datas.Uptime = time.Unix(StartTime().Unix()+uptime/100, int64((float64(uptime)/100-float64(uptime/100))*1000000000)) + if uptime, ok := procStartTimeFromStat([]byte(main)); ok { + datas.Uptime = uptime + } return } +func procStartTimeFromStat(content []byte) (time.Time, bool) { + fields := splitProcStat(content) + if len(fields) <= 22 { + return time.Time{}, false + } + startTicks, err := strconv.ParseInt(strings.TrimSpace(fields[22]), 10, 64) + if err != nil { + return time.Time{}, false + } + ticks := int64(clockTicks()) + seconds := startTicks / ticks + nanos := (startTicks % ticks) * int64(time.Second) / ticks + return time.Unix(StartTime().Unix()+seconds, nanos), true +} + +func atoiField(fields []string, index int) (int, error) { + if index < 0 || index >= len(fields) { + return 0, errors.New("field index out of range") + } + return strconv.Atoi(fields[index]) +} + +func parseProcStatusKB(value string) int64 { + fields := splitBySpace(value) + if len(fields) == 0 || fields[0] == "" { + return 0 + } + size, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0 + } + return size * 1024 +} + +type netSnapshot struct { + conns []NetConn + err error +} + +func loadNetSnapshot(analysePid bool) netSnapshot { + netInfo, err := NetConnections(analysePid, "") + return netSnapshot{conns: netInfo, err: err} +} + +func appendNetInfo(p *Process, analysePid bool) { + loadNetSnapshot(analysePid).appendTo(p) +} + +func (snapshot netSnapshot) appendTo(p *Process) { + if snapshot.err != nil { + p.netErr = snapshot.err + return + } + fds, err := ioutil.ReadDir("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd") + if err != nil { + if Exists("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd") { + p.netErr = err + } + return + } + for _, fd := range fds { + socket, err := os.Readlink("/proc/" + strconv.Itoa(int(p.Pid)) + "/fd/" + fd.Name()) + if err != nil { + continue + } + start := strings.Index(socket, "[") + if start < 0 { + continue + } + sid := socket[start+1 : len(socket)-1] + for _, v := range snapshot.conns { + if v.Inode == sid { + v.Pid = p.Pid + v.Process = p + p.netConn = append(p.netConn, v) + } + } + } +} + func Daemon(path string, args ...string) (int, error) { cmd := exec.Command(path, args...) cmd.SysProcAttr = &syscall.SysProcAttr{ @@ -302,17 +311,58 @@ func DaemonWithUser(uid, gid uint32, groups []uint32, path string, args ...strin } func (starcli *StarCmd) SetRunUser(uid, gid uint32, groups []uint32) { - starcli.CMD.SysProcAttr = &syscall.SysProcAttr{ - Credential: &syscall.Credential{ - Uid: uid, - Gid: gid, - Groups: groups, - }, - Setsid: true, + _ = starcli.SetRunUserE(uid, gid, groups) +} + +func (starcli *StarCmd) SetRunUserE(uid, gid uint32, groups []uint32) error { + if starcli == nil || starcli.CMD == nil { + return errNilCommand } + if atomic.LoadInt32(&starcli.started) != 0 { + return errCommandAlreadyStarted + } + if starcli.CMD.SysProcAttr == nil { + starcli.CMD.SysProcAttr = &syscall.SysProcAttr{} + } + if starcli.CMD.SysProcAttr.Credential == nil { + starcli.CMD.SysProcAttr.Credential = &syscall.Credential{} + } + starcli.CMD.SysProcAttr.Credential.Uid = uid + starcli.CMD.SysProcAttr.Credential.Gid = gid + starcli.CMD.SysProcAttr.Credential.Groups = append([]uint32(nil), groups...) + starcli.CMD.SysProcAttr.Setsid = true + return nil } func (starcli *StarCmd) Release() error { + return starcli.ReleaseE() +} + +func (starcli *StarCmd) Detach() error { + return starcli.DetachE() +} + +func (starcli *StarCmd) ReleaseE() error { + if starcli == nil || starcli.CMD == nil { + return errNilCommand + } + if !atomic.CompareAndSwapInt32(&starcli.released, 0, 1) { + return errCommandAlreadyReleased + } + if atomic.LoadInt32(&starcli.started) != 0 { + if starcli.CMD.Process == nil { + starcli.lock.Lock() + err := starcli.runerr + starcli.lock.Unlock() + if err != nil { + atomic.StoreInt32(&starcli.released, 0) + return err + } + atomic.StoreInt32(&starcli.released, 0) + return errCommandAlreadyStarted + } + return nil + } if starcli.CMD.SysProcAttr == nil { starcli.CMD.SysProcAttr = &syscall.SysProcAttr{ Setsid: true, @@ -322,27 +372,121 @@ func (starcli *StarCmd) Release() error { starcli.CMD.SysProcAttr.Setsid = true } } - if !starcli.IsRunning() { - if err := starcli.CMD.Start(); err != nil { - return err + if err := starcli.Start(); err != nil { + atomic.StoreInt32(&starcli.released, 0) + return err + } + return nil +} + +func (starcli *StarCmd) DetachE() error { + if starcli == nil || starcli.CMD == nil { + return errNilCommand + } + if !atomic.CompareAndSwapInt32(&starcli.detached, 0, 1) { + return errCommandAlreadyDetached + } + if atomic.LoadInt32(&starcli.started) != 0 { + atomic.StoreInt32(&starcli.detached, 0) + return errCommandAlreadyStarted + } + cmd := exec.Command(starcli.CMD.Path, starcli.CMD.Args[1:]...) + cmd.Dir = starcli.CMD.Dir + cmd.Env = append([]string(nil), starcli.CMD.Env...) + if starcli.CMD.SysProcAttr != nil { + attr := *starcli.CMD.SysProcAttr + cmd.SysProcAttr = &attr + } else { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, } } - time.Sleep(time.Millisecond * 10) - return starcli.CMD.Process.Release() + if !cmd.SysProcAttr.Setsid { + cmd.SysProcAttr.Setsid = true + } + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0) + if err != nil { + atomic.StoreInt32(&starcli.detached, 0) + return err + } + defer devNull.Close() + cmd.Stdin = devNull + cmd.Stdout = devNull + cmd.Stderr = devNull + if err := cmd.Start(); err != nil { + atomic.StoreInt32(&starcli.detached, 0) + return err + } + starcli.CMD.Process = cmd.Process + atomic.StoreInt32(&starcli.started, 1) + starcli.setRunning(false) + starcli.finish() + if err := cmd.Process.Release(); err != nil { + atomic.StoreInt32(&starcli.detached, 0) + return err + } + return nil } func (starcli *StarCmd) SetKeepCaps() error { - _, _, err := syscall.RawSyscall(157 /*SYS PRCTL */, 0x8 /*PR SET KEEPCAPS*/, 1, 0) - if 0 != err { + if err := starcli.ensureConfigurable(); err != nil { return err } + caps, err := loadCurrentKeepCaps() + if err != nil { + return err + } + starcli.lock.Lock() + defer starcli.lock.Unlock() + if starcli.CMD.SysProcAttr == nil { + starcli.CMD.SysProcAttr = &syscall.SysProcAttr{} + } + starcli.CMD.SysProcAttr.AmbientCaps = mergeAmbientCaps(starcli.CMD.SysProcAttr.AmbientCaps, caps) return nil } func SetKeepCaps() error { - _, _, err := syscall.RawSyscall(157 /*SYS PRCTL */, 0x8 /*PR SET KEEPCAPS*/, 1, 0) - if 0 != err { - return err - } - return nil + return unix.Prctl(unix.PR_SET_KEEPCAPS, 1, 0, 0, 0) +} + +func currentKeepCaps() ([]uintptr, error) { + hdr := unix.CapUserHeader{Version: unix.LINUX_CAPABILITY_VERSION_3} + data := [2]unix.CapUserData{} + if err := unix.Capget(&hdr, &data[0]); err != nil { + return nil, err + } + return capsFromCapData(data), nil +} + +func capsFromCapData(data [2]unix.CapUserData) []uintptr { + var caps []uintptr + for index, item := range data { + mask := item.Permitted + for bit := uint(0); bit < 32; bit++ { + if mask&(1<= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') { + if r >= 'a' && r <= 'z' { + r -= 'a' - 'A' + } + builder.WriteRune(r) + underscore = false + continue + } + if !underscore { + builder.WriteByte('_') + underscore = true + } + } + return strings.Trim(builder.String(), "_") +} + +func joinFieldPath(prefix, name string) string { + if prefix == "" { + return name + } + return prefix + "." + name +} + +func fieldNameKey(name string) string { + var builder strings.Builder + for idx, r := range name { + if r >= 'A' && r <= 'Z' { + if idx > 0 { + builder.WriteByte('_') + } + r += 'a' - 'A' + } + builder.WriteRune(r) + } + return builder.String() +} + +func parseTagBool(value string) bool { + switch strings.ToLower(strings.TrimSpace(value)) { + case "1", "t", "true", "y", "yes", "required": + return true + default: + return false + } +} + +func sampleValuesForField(field configField) ([]string, error) { + if field.defaultValue != "" { + return defaultValuesForField(field, field.defaultValue) + } + values, err := outputValuesForField(field) + if err != nil { + return nil, err + } + if field.required && needsRequiredSamplePlaceholder(values) { + return requiredSampleValuesForType(field.value.Type()) + } + if len(values) > 0 { + return values, nil + } + values, err = sampleValuesForType(field.value.Type()) + if err != nil { + return nil, err + } + if field.required && needsRequiredSamplePlaceholder(values) { + return requiredSampleValuesForType(field.value.Type()) + } + return values, nil +} + +func sampleValuesForType(t reflect.Type) ([]string, error) { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + switch t.Kind() { + case reflect.Slice: + return sampleScalarValue(t.Elem()) + case reflect.Array: + if t.Len() == 0 { + return nil, nil + } + values := make([]string, 0, t.Len()) + for idx := 0; idx < t.Len(); idx++ { + items, err := sampleScalarValue(t.Elem()) + if err != nil { + return nil, err + } + values = append(values, items...) + } + return values, nil + case reflect.Map: + if t.Key().Kind() != reflect.String { + return nil, fmt.Errorf("unsupported map key kind %s", t.Key().Kind()) + } + values, err := sampleScalarValue(t.Elem()) + if err != nil { + return nil, err + } + if len(values) == 0 { + return nil, nil + } + return []string{"key=" + values[0]}, nil + default: + return sampleScalarValue(t) + } +} + +func sampleScalarValue(t reflect.Type) ([]string, error) { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + value := reflect.Zero(t) + text, err := scalarToString(value) + if err != nil { + return nil, err + } + return []string{text}, nil +} + +func needsRequiredSamplePlaceholder(values []string) bool { + if len(values) == 0 { + return true + } + for _, value := range values { + if strings.TrimSpace(value) != "" { + return false + } + } + return true +} + +func requiredSampleValuesForType(t reflect.Type) ([]string, error) { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + switch t.Kind() { + case reflect.String: + return []string{"value"}, nil + case reflect.Slice: + return requiredSampleValuesForType(t.Elem()) + case reflect.Array: + if t.Len() == 0 { + return nil, nil + } + values := make([]string, 0, t.Len()) + for idx := 0; idx < t.Len(); idx++ { + items, err := requiredSampleValuesForType(t.Elem()) + if err != nil { + return nil, err + } + values = append(values, items...) + } + return values, nil + case reflect.Map: + if t.Key().Kind() != reflect.String { + return nil, fmt.Errorf("unsupported map key kind %s", t.Key().Kind()) + } + values, err := requiredSampleValuesForType(t.Elem()) + if err != nil { + return nil, err + } + if needsRequiredSamplePlaceholder(values) { + values, err = sampleValuesForType(t.Elem()) + if err != nil { + return nil, err + } + } + if len(values) == 0 { + return nil, nil + } + return []string{"key=" + values[0]}, nil + default: + return sampleValuesForType(t) + } +} + +func defaultValuesForField(field configField, text string) ([]string, error) { + v := field.value + for v.Kind() == reflect.Ptr { + if v.IsNil() { + v = reflect.New(v.Type().Elem()).Elem() + break + } + v = v.Elem() + } + switch v.Kind() { + case reflect.Slice, reflect.Array, reflect.Map: + return splitFieldValue(text, field.split) + } + return []string{text}, nil +} + +func splitFieldValue(text, split string) ([]string, error) { + switch { + case split == "": + return []string{text}, nil + case split == "," || strings.EqualFold(split, "csv"): + return parseCSVList(text) + default: + return parseDelimitedList(text, split), nil + } +} + +func parseCSVList(text string) ([]string, error) { + if text == "" { + return []string{""}, nil + } + reader := csv.NewReader(strings.NewReader(text)) + reader.TrimLeadingSpace = true + record, err := reader.Read() + if err != nil { + return nil, err + } + for idx := range record { + record[idx] = strings.TrimSpace(record[idx]) + } + return record, nil +} + +func parseDelimitedList(text, split string) []string { + if split == "" { + return []string{text} + } + parts := strings.Split(text, split) + for idx := range parts { + parts[idx] = strings.TrimSpace(parts[idx]) + } + return parts +} + +func setConfigField(value reflect.Value, section *Section, key, split string) error { + for value.Kind() == reflect.Ptr { + if value.IsNil() { + value.Set(reflect.New(value.Type().Elem())) + } + value = value.Elem() + } + items, err := expandConfigValues(value, configValuesFromSection(section, key), split) + if err != nil { + return err + } + if len(items) == 0 { + return nil + } + return setConfigValueItems(value, items) +} + +func setConfigValueItems(value reflect.Value, items []configValue) error { + switch value.Kind() { + case reflect.Slice: + return setConfigSlice(value, items) + case reflect.Array: + return setConfigArray(value, items) + case reflect.Map: + return setConfigMap(value, items) + default: + return setConfigScalar(value, items[0]) + } +} + +func expandConfigValues(value reflect.Value, items []configValue, split string) ([]configValue, error) { + if split == "" { + return items, nil + } + switch value.Kind() { + case reflect.Slice, reflect.Array, reflect.Map: + default: + return items, nil + } + out := make([]configValue, 0, len(items)) + for _, item := range items { + if item.noValue { + out = append(out, item) + continue + } + parts, err := splitFieldValue(item.text, split) + if err != nil { + return nil, err + } + for _, part := range parts { + out = append(out, configValue{text: part}) + } + } + return out, nil +} + +func configValuesFromSection(section *Section, key string) []configValue { + return configValuesFromSections([]*Section{section}, key) +} + +func configValuesFromSections(sections []*Section, key string) []configValue { + values := make([]configValue, 0) + for _, section := range sections { + if section == nil { + continue + } + entries := section.EntriesByKey(key) + for _, entry := range entries { + if entry == nil { + continue + } + if entry.NoValue && len(entry.Values) == 0 { + values = append(values, configValue{noValue: true}) + continue + } + for _, value := range entry.Values { + values = append(values, configValue{text: value}) + } + } + } + if len(values) == 0 { + return nil + } + return values +} + +func setConfigSlice(value reflect.Value, items []configValue) error { + out := reflect.MakeSlice(value.Type(), 0, len(items)) + for _, item := range items { + elem := reflect.New(value.Type().Elem()).Elem() + if err := setConfigScalar(elem, item); err != nil { + return err + } + out = reflect.Append(out, elem) + } + value.Set(out) + return nil +} + +func setConfigArray(value reflect.Value, items []configValue) error { + if len(items) != value.Len() { + return fmt.Errorf("array needs %d values, got %d", value.Len(), len(items)) + } + for idx, item := range items { + if err := setConfigScalar(value.Index(idx), item); err != nil { + return err + } + } + return nil +} + +func setConfigMap(value reflect.Value, items []configValue) error { + if value.Type().Key().Kind() != reflect.String { + return fmt.Errorf("unsupported map key kind %s", value.Type().Key().Kind()) + } + out := reflect.MakeMapWithSize(value.Type(), len(items)) + for idx, item := range items { + key, text := parseMapItem(idx, item.text) + elem := reflect.New(value.Type().Elem()).Elem() + mapItem := configValue{text: text, noValue: item.noValue} + if err := setConfigScalar(elem, mapItem); err != nil { + return err + } + out.SetMapIndex(reflect.ValueOf(key), elem) + } + value.Set(out) + return nil +} + +func parseMapItem(idx int, text string) (string, string) { + if at := strings.Index(text, "="); at >= 0 { + return strings.TrimSpace(text[:at]), strings.TrimSpace(text[at+1:]) + } + return strconv.Itoa(idx), text +} + +func outputValuesForField(field configField) ([]string, error) { + value := field.value + for value.Kind() == reflect.Ptr { + if value.IsNil() { + return nil, nil + } + value = value.Elem() + } + switch value.Kind() { + case reflect.Slice, reflect.Array: + values := make([]string, 0, value.Len()) + for idx := 0; idx < value.Len(); idx++ { + text, err := scalarToString(value.Index(idx)) + if err != nil { + return nil, err + } + values = append(values, text) + } + return values, nil + case reflect.Map: + if value.Type().Key().Kind() != reflect.String { + return nil, fmt.Errorf("unsupported map key kind %s", value.Type().Key().Kind()) + } + keys := make([]string, 0, value.Len()) + for _, key := range value.MapKeys() { + keys = append(keys, key.String()) + } + sort.Strings(keys) + values := make([]string, 0, len(keys)) + for _, key := range keys { + text, err := scalarToString(value.MapIndex(reflect.ValueOf(key))) + if err != nil { + return nil, err + } + values = append(values, key+"="+text) + } + return values, nil + default: + text, err := scalarToString(value) + if err != nil { + return nil, err + } + return []string{text}, nil + } +} + +func scalarToString(value reflect.Value) (string, error) { + for value.Kind() == reflect.Ptr { + if value.IsNil() { + return "", nil + } + value = value.Elem() + } + if ok, text, err := textMarshalerString(value); ok || err != nil { + return text, err + } + if value.Type() == reflect.TypeOf(time.Duration(0)) { + return time.Duration(value.Int()).String(), nil + } + switch value.Kind() { + case reflect.String: + return value.String(), nil + case reflect.Bool: + return strconv.FormatBool(value.Bool()), nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(value.Int(), 10), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(value.Uint(), 10), nil + case reflect.Float32, reflect.Float64: + return strconv.FormatFloat(value.Float(), 'g', -1, value.Type().Bits()), nil + default: + return "", fmt.Errorf("unsupported field kind %s", value.Kind()) + } +} + +func textMarshalerString(value reflect.Value) (bool, string, error) { + if value.CanInterface() { + if marshaler, ok := value.Interface().(encoding.TextMarshaler); ok { + data, err := marshaler.MarshalText() + return true, string(data), err + } + } + if value.CanAddr() { + if marshaler, ok := value.Addr().Interface().(encoding.TextMarshaler); ok { + data, err := marshaler.MarshalText() + return true, string(data), err + } + } + return false, "", nil +} + +func setConfigScalar(value reflect.Value, item configValue) error { + for value.Kind() == reflect.Ptr { + if value.IsNil() { + value.Set(reflect.New(value.Type().Elem())) + } + value = value.Elem() + } + if ok, err := setTextUnmarshaler(value, item.text); ok || err != nil { + if err != nil { + return err + } + return nil + } + if value.Type() == reflect.TypeOf(time.Duration(0)) { + duration, err := time.ParseDuration(item.text) + if err != nil { + return err + } + value.SetInt(int64(duration)) + return nil + } + switch value.Kind() { + case reflect.String: + value.SetString(item.text) + case reflect.Bool: + if item.noValue { + value.SetBool(true) + return nil + } + v, err := strconv.ParseBool(item.text) + if err != nil { + return err + } + value.SetBool(v) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(item.text, 10, value.Type().Bits()) + if err != nil { + return err + } + value.SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + n, err := strconv.ParseUint(item.text, 10, value.Type().Bits()) + if err != nil { + return err + } + value.SetUint(n) + case reflect.Float32, reflect.Float64: + n, err := strconv.ParseFloat(item.text, value.Type().Bits()) + if err != nil { + return err + } + value.SetFloat(n) + default: + return fmt.Errorf("unsupported field kind %s", value.Kind()) + } + return nil +} + +func setTextUnmarshaler(value reflect.Value, text string) (bool, error) { + if value.CanAddr() { + if unmarshaler, ok := value.Addr().Interface().(encoding.TextUnmarshaler); ok { + return true, unmarshaler.UnmarshalText([]byte(text)) + } + } + if value.CanInterface() { + if unmarshaler, ok := value.Interface().(encoding.TextUnmarshaler); ok { + return true, unmarshaler.UnmarshalText([]byte(text)) + } + } + return false, nil +} + +func implementsTextUnmarshaler(t reflect.Type) bool { + textUnmarshaler := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + if t.Implements(textUnmarshaler) { + return true + } + return reflect.PtrTo(t).Implements(textUnmarshaler) +} + +func sectionHasPairs(section *Section) bool { + if section == nil { + return false + } + for _, entry := range section.Entries { + if isConfigEntry(entry) { + return true + } + } + return false +} + +func isConfigEntry(entry *Entry) bool { + return entry != nil && entry.Key != "" +} + +func configPath(section, key string) string { + if section == "" { + return key + } + return section + "." + key +} + +func configFieldError(field configField, reason string) error { + return configFieldErrorWithErr(field, reason, nil) +} + +func configFieldErrorWithErr(field configField, reason string, err error) error { + return &ConfigError{ + Section: field.section, + Key: field.key, + Field: field.structField.Name, + Reason: reason, + Err: err, + } +} + +func configKeyError(section, key, reason string, err error) error { + return &ConfigError{ + Section: section, + Key: key, + Reason: reason, + Err: err, + } +} + +func validateConfig(dst interface{}) error { + if validator, ok := dst.(ConfigValidator); ok { + return validator.Validate() + } + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr && !value.IsNil() { + if validator, ok := value.Elem().Interface().(ConfigValidator); ok { + return validator.Validate() + } + } + return nil +} diff --git a/sysconf/csv.go b/sysconf/csv.go index 0980456..7ae62f2 100644 --- a/sysconf/csv.go +++ b/sysconf/csv.go @@ -2,13 +2,15 @@ package sysconf import ( "bytes" + "encoding/csv" "errors" "fmt" "reflect" - "strconv" "strings" ) +var ErrNilCSVValue = errors.New("nil csv value") + type CSV struct { header []string text [][]string @@ -24,238 +26,152 @@ type CSVValue struct { value string } -func ParseCSV(data []byte, hasHeader bool) (csv CSV, err error) { - strData := strings.Split(string(bytes.TrimSpace(data)), "\n") - if len(strData) < 1 { - err = fmt.Errorf("cannot parse data,invalid data format") +func ParseCSV(data []byte, hasHeader bool) (csvData CSV, err error) { + if len(data) == 0 { + return CSV{}, fmt.Errorf("cannot parse data,invalid data format") } - var header []string - var text [][]string + reader := csv.NewReader(bytes.NewReader(data)) + records, err := reader.ReadAll() + if err != nil { + return CSV{}, err + } + if len(records) == 0 { + return CSV{}, fmt.Errorf("cannot parse data,invalid data format") + } + start := 0 if hasHeader { - header = csvAnalyse(strData[0]) - strData = strData[1:] + csvData.header = append([]string(nil), records[0]...) + start = 1 } else { - num := len(csvAnalyse(strData[0])) - for i := 0; i < num; i++ { - header = append(header, strconv.Itoa(i)) + for i := range records[0] { + csvData.header = append(csvData.header, fmt.Sprint(i)) } } - for k, v := range strData { - tmpData := csvAnalyse(v) - if len(tmpData) != len(header) { - err = fmt.Errorf("cannot parse data line %d,got %d values but need %d", k, len(tmpData), len(header)) - return + for _, record := range records[start:] { + if len(record) != len(csvData.header) { + return CSV{}, fmt.Errorf("cannot parse data line,got %d values but need %d", len(record), len(csvData.header)) } - text = append(text, tmpData) + csvData.text = append(csvData.text, append([]string(nil), record...)) } - csv.header = header - csv.text = text - return + return csvData, nil } -func (csv *CSV) Header() []string { - return csv.header -} +func (csvData *CSV) Header() []string { return csvData.header } +func (csvData *CSV) Data() [][]string { return csvData.text } -func (csv *CSV) Data() [][]string { - return csv.text -} - -func (csv *CSV) Row(row int) *CSVRow { - if row >= len(csv.Data()) { +func (csvData *CSV) Row(row int) *CSVRow { + if csvData == nil || row < 0 || row >= len(csvData.text) { return nil } - return &CSVRow{ - header: csv.Header(), - data: csv.Data()[row], - } + return &CSVRow{header: csvData.header, data: csvData.text[row]} } -func (csv *CSVRow) Get(key string) *CSVValue { - for k, v := range csv.header { - if v == key { - return &CSVValue{ - key: key, - value: csv.data[k], - } +func (row *CSVRow) Get(key string) *CSVValue { + if row == nil { + return nil + } + for idx, header := range row.header { + if header == key { + return &CSVValue{key: key, value: row.data[idx]} } } return nil } -func (csv *CSVRow) Col(key int) *CSVValue { - if key >= len(csv.header) { +func (row *CSVRow) Col(key int) *CSVValue { + if row == nil || key < 0 || key >= len(row.header) { return nil } - return &CSVValue{ - key: csv.header[key], - value: csv.data[key], - } + return &CSVValue{key: row.header[key], value: row.data[key]} } -func (csv *CSVRow) Header() []string { - return csv.header -} +func (row *CSVRow) Header() []string { return row.header } -func (csv *CSV) MapData() []map[string]string { +func (csvData *CSV) MapData() []map[string]string { var result []map[string]string - for _, v := range csv.text { - tmp := make(map[string]string) - for k, v2 := range csv.header { - tmp[v2] = v[k] + for _, record := range csvData.text { + item := make(map[string]string, len(csvData.header)) + for idx, header := range csvData.header { + item[header] = record[idx] } - result = append(result, tmp) + result = append(result, item) } return result } -func CsvAnalyse(data string) []string { - return csvAnalyse(data) -} +func CsvAnalyse(data string) []string { return csvAnalyse(data) } func csvAnalyse(data string) []string { - var segStart bool = false - var segReady bool = false - var segSign string = "" - var dotReady bool = false - data = strings.TrimSpace(data) - var result []string - var seg string - for k, v := range []rune(data) { - if k == 0 && v != []rune(`"`)[0] { - dotReady = true - } - if v != []rune(`,`)[0] && dotReady { - segSign = `,` - segStart = true - dotReady = false - if v == []rune(`"`)[0] { - segSign = `"` - continue - } - } - if dotReady && v == []rune(`,`)[0] { - //dotReady = false - result = append(result, "") - continue - } - - if v == []rune(`"`)[0] && segStart { - if !segReady { - segReady = true - continue - } - seg += `"` - segReady = false - continue - } - if segReady && segSign == `"` && segStart { - segReady = false - segStart = false - result = append(result, seg) - segSign = `` - seg = "" - } - - if v == []rune(`"`)[0] && !segStart { - segStart = true - segReady = false - segSign = `"` - continue - } - if v == []rune(`,`)[0] && !segStart { - dotReady = true - } - if v == []rune(`,`)[0] && segStart && segSign == "," { - segStart = false - result = append(result, seg) - dotReady = true - segSign = `` - seg = "" - } - if segStart { - seg = string(append([]rune(seg), v)) - } + reader := csv.NewReader(strings.NewReader(data)) + record, err := reader.Read() + if err != nil { + return []string{} } - if len(data) != 0 && len(result) == 0 && seg == "" { - result = append(result, data) - } else { - result = append(result, seg) - } - - return result + return record } func MarshalCSV(header []string, ins interface{}) ([]byte, error) { - var result [][]string - t := reflect.TypeOf(ins) v := reflect.ValueOf(ins) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, errors.New("not a Slice or Array") - } - if t.Kind() == reflect.Ptr { - t = t.Elem() + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return nil, ErrNilCSVValue + } v = v.Elem() } - + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return nil, fmt.Errorf("not a Slice or Array") + } + rows := make([][]string, 0, v.Len()) for i := 0; i < v.Len(); i++ { - subT := reflect.TypeOf(v.Index(i).Interface()) - subV := reflect.ValueOf(v.Index(i).Interface()) - if subV.Kind() == reflect.Slice || subV.Kind() == reflect.Array { - if subT.Kind() == reflect.Ptr { - subV = subV.Elem() + item := v.Index(i) + if item.Kind() == reflect.Ptr { + if item.IsNil() { + continue } - var tmp []string - for j := 0; j < subV.Len(); j++ { - tmp = append(tmp, fmt.Sprint(reflect.ValueOf(subV.Index(j)))) - } - result = append(result, tmp) + item = item.Elem() } - if subV.Kind() == reflect.Struct { - var tmp []string - if subT.Kind() == reflect.Ptr { - subV = subV.Elem() + switch item.Kind() { + case reflect.Slice, reflect.Array: + row := make([]string, 0, item.Len()) + for j := 0; j < item.Len(); j++ { + row = append(row, fmt.Sprint(item.Index(j).Interface())) } - for i := 0; i < subV.NumField(); i++ { - tmp = append(tmp, fmt.Sprint(subV.Field(i))) + rows = append(rows, row) + case reflect.Struct: + row := make([]string, 0, item.NumField()) + for j := 0; j < item.NumField(); j++ { + field := item.Field(j) + if !field.CanInterface() { + continue + } + row = append(row, fmt.Sprint(field.Interface())) } - result = append(result, tmp) + rows = append(rows, row) } } - - return buildCSV(header,result) -} - -func buildCSV(header []string, data [][]string) ([]byte, error) { - var result []string - var length int - build := func(slc []string) string { - for k, v := range slc { - if strings.Index(v, `"`) >= 0 { - v = strings.ReplaceAll(v, `"`, `""`) - } - if strings.Index(v,"\n")>=0 { - v=strings.ReplaceAll(v,"\n",`\n`) - } - if strings.Index(v,"\r")>=0 { - v=strings.ReplaceAll(v,"\r",`\r`) - } - v = `"` + v + `"` - slc[k] = v - } - return strings.Join(slc, ",") - } - if len(header) != 0 { - result = append(result, build(header)) - length = len(header) - } else { - length = len(data[0]) - } - for k, v := range data { - if len(v) != length { - return nil, fmt.Errorf("line %d got length %d ,but need %d", k, len(v), length) - } - result = append(result, build(v)) - } - return []byte(strings.Join(result, "\n")), nil + width := 0 + if len(header) > 0 { + width = len(header) + } else if len(rows) > 0 { + width = len(rows[0]) + } + for idx, row := range rows { + if len(row) != width { + return nil, fmt.Errorf("line %d got length %d ,but need %d", idx, len(row), width) + } + } + var buf bytes.Buffer + writer := csv.NewWriter(&buf) + if len(header) > 0 { + if err := writer.Write(header); err != nil { + return nil, err + } + } + for _, row := range rows { + if err := writer.Write(row); err != nil { + return nil, err + } + } + writer.Flush() + return buf.Bytes(), writer.Error() } diff --git a/sysconf/csv_test.go b/sysconf/csv_test.go deleted file mode 100644 index 799e8ba..0000000 --- a/sysconf/csv_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package sysconf - -import ( - "fmt" - "testing" -) - -func Test_csv(t *testing.T) { - //var test Sqlplus - var text=` -姓名,班级,性别,年龄 -张三,"我,不""知道",boy,23 -"里斯","哈哈",girl,23 -` - fmt.Println(csvAnalyse(`请求权,lkjdshck,dsvdsv,"sdvkjsdv,",=dsvdsv,"=,dsvsdv"`)) - a,b:=ParseCSV([]byte(text),true) - fmt.Println(b) - fmt.Println(a.Row(0).Col(3).MustInt()) -} - -type csvtest struct { - A string - B int -} -func Test_Masharl(t *testing.T) { - //var test Sqlplus - /* - var a []csvtest = []csvtest{ - {"lala",1}, - {"haha",34}, - } - */ - var a [][]string - a=append(a,[]string{"a","b","c"}) - a=append(a,[]string{"1",`s"s"d`,"3"}) - b,_:=MarshalCSV([]string{},a) - fmt.Println(string(b)) -} diff --git a/sysconf/csvconvert.go b/sysconf/csvconvert.go deleted file mode 100644 index 54979ef..0000000 --- a/sysconf/csvconvert.go +++ /dev/null @@ -1,120 +0,0 @@ -package sysconf - -import "strconv" - -func (csv *CSVValue)Key()string { - return csv.key -} - -func (csv *CSVValue)Int()(int,error) { - tmp,err:=strconv.Atoi(csv.value) - return tmp,err -} - -func (csv *CSVValue)MustInt()int { - tmp,err:=csv.Int() - if err!=nil { - panic(err) - } - return tmp -} - -func (csv *CSVValue)Int64()(int64,error) { - tmp,err:=strconv.ParseInt(csv.value,10,64) - return tmp,err -} - -func (csv *CSVValue)MustInt64()int64 { - tmp,err:=csv.Int64() - if err!=nil { - panic(err) - } - return tmp -} - -func (csv *CSVValue)Int32()(int32,error) { - tmp,err:=strconv.ParseInt(csv.value,10,32) - return int32(tmp),err -} - -func (csv *CSVValue)MustInt32()int32 { - tmp,err:=csv.Int32() - if err!=nil { - panic(err) - } - return tmp -} - -func (csv *CSVValue)Uint64()(uint64,error) { - tmp,err:=strconv.ParseUint(csv.value,10,64) - return tmp,err -} - -func (csv *CSVValue)MustUint64()uint64 { - tmp,err:=csv.Uint64() - if err!=nil { - panic(err) - } - return tmp -} - -func (csv *CSVValue)Uint32()(uint32,error) { - tmp,err:=strconv.ParseUint(csv.value,10,32) - return uint32(tmp),err -} - -func (csv *CSVValue)MustUint32()uint32 { - tmp,err:=csv.Uint32() - if err!=nil { - panic(err) - } - return tmp -} - -func (csv *CSVValue)String()string { - return csv.value -} - -func (csv *CSVValue)Byte()[]byte { - return []byte(csv.value) -} - - -func (csv *CSVValue)Bool()(bool,error) { - tmp,err:=strconv.ParseBool(csv.value) - return tmp,err -} - -func (csv *CSVValue)MustBool()bool { - tmp,err:=csv.Bool() - if err!=nil { - panic(err) - } - return tmp -} - -func (csv *CSVValue)Float64()(float64,error) { - tmp,err:=strconv.ParseFloat(csv.value,64) - return tmp,err -} - -func (csv *CSVValue)MustFloat64()float64 { - tmp,err:=csv.Float64() - if err!=nil { - panic(err) - } - return tmp -} - -func (csv *CSVValue)Float32()(float32,error) { - tmp,err:=strconv.ParseFloat(csv.value,32) - return float32(tmp),err -} - -func (csv *CSVValue)MustFloat32()float32 { - tmp,err:=csv.Float32() - if err!=nil { - panic(err) - } - return tmp -} \ No newline at end of file diff --git a/sysconf/document.go b/sysconf/document.go new file mode 100644 index 0000000..d73e3b7 --- /dev/null +++ b/sysconf/document.go @@ -0,0 +1,1192 @@ +package sysconf + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" +) + +var ( + ErrDocumentClosed = errors.New("document is nil") + ErrSectionNotFound = errors.New("section not found") + ErrKeyNotFound = errors.New("key not found") +) + +type Document struct { + mu sync.RWMutex + sections []*Section + sectionIndex map[string][]*Section + SectionOpen string + SectionClose string + Assign string + AssignDelimiters []string + CommentHeads []string + AllowInline bool + InlineCommentRequiresSpace bool + AllowNoValue bool + AllowMulti bool + AllowContinuation bool + TrimSpace bool + CaseSensitive bool + Strict bool +} + +type Section struct { + Name string + HeaderComment string + Raw string + Newline string + Entries []*Entry + entryIndex map[string][]*Entry + CaseSensitive bool + rawName string + parsedHeaderComment string +} + +type Entry struct { + Key string + Values []string + Comment string + Raw string + NoValue bool + Delimiter string + Newline string + kind lineKind + + parsedKey string + parsedValues []string + parsedComment string + parsedNoValue bool + parsedDelimiter string +} + +type parsedLine struct { + kind lineKind + raw string + newline string + key string + value string + comment string + sectionName string + noValue bool + delimiter string + line int +} + +type lineKind int + +const ( + lineEmpty lineKind = iota + lineComment + lineSection + linePair + lineRaw +) + +func NewDocument() *Document { + return &Document{ + sectionIndex: make(map[string][]*Section), + SectionOpen: "[", + SectionClose: "]", + Assign: "=", + AssignDelimiters: []string{"=", ":"}, + CommentHeads: []string{"#", ";"}, + AllowInline: true, + InlineCommentRequiresSpace: true, + AllowNoValue: true, + AllowMulti: true, + AllowContinuation: true, + TrimSpace: true, + } +} + +type ParseError struct { + Line int + Column int + Message string +} + +func (e *ParseError) Error() string { + if e == nil { + return "" + } + if e.Column > 0 { + return fmt.Sprintf("sysconf: parse error at line %d, column %d: %s", e.Line, e.Column, e.Message) + } + return fmt.Sprintf("sysconf: parse error at line %d: %s", e.Line, e.Message) +} + +func normalize(text string, caseSensitive bool) string { + if caseSensitive { + return text + } + return strings.ToLower(text) +} + +func (d *Document) Parse(data []byte) error { + if d == nil { + return ErrDocumentClosed + } + d.mu.Lock() + defer d.mu.Unlock() + + d.sections = nil + d.sectionIndex = make(map[string][]*Section) + + current := d.appendSection("", "", "", "") + lines := splitSourceLines(string(data)) + for idx := 0; idx < len(lines); idx++ { + source := lines[idx] + raw := source.text + text := source.text + newline := source.newline + if d.AllowContinuation { + var err error + raw, text, newline, idx, err = d.consumeContinuation(lines, idx) + if err != nil { + return err + } + } + line, err := parseRawLine(raw, text, d, source.line) + if err != nil { + return err + } + line.newline = newline + switch line.kind { + case lineSection: + current = d.appendSection(line.sectionName, line.comment, line.raw, line.newline) + case linePair, lineComment, lineEmpty, lineRaw: + if current == nil { + current = d.ensureSection("") + } + current.addParsed(line) + } + } + return nil +} + +type sourceLine struct { + text string + newline string + line int +} + +func splitSourceLines(text string) []sourceLine { + if text == "" { + return nil + } + lines := make([]sourceLine, 0) + start := 0 + line := 1 + for idx := 0; idx < len(text); idx++ { + switch text[idx] { + case '\n': + lines = append(lines, sourceLine{text: text[start:idx], newline: "\n", line: line}) + start = idx + 1 + line++ + case '\r': + newline := "\r" + end := idx + 1 + if end < len(text) && text[end] == '\n' { + newline = "\r\n" + end++ + } + lines = append(lines, sourceLine{text: text[start:idx], newline: newline, line: line}) + start = end + idx = end - 1 + line++ + } + } + if start < len(text) { + lines = append(lines, sourceLine{text: text[start:], line: line}) + } + return lines +} + +func (d *Document) consumeContinuation(lines []sourceLine, idx int) (string, string, string, int, error) { + source := lines[idx] + raw := source.text + text := source.text + newline := source.newline + for hasContinuation(text) { + if idx+1 >= len(lines) { + if d.Strict { + return raw, text, newline, idx, &ParseError{Line: source.line, Column: continuationColumn(text), Message: "line continuation has no following line"} + } + return raw, text, newline, idx, nil + } + next := lines[idx+1] + raw += newline + next.text + text = trimContinuation(text) + strings.TrimLeft(next.text, " \t") + newline = next.newline + idx++ + } + return raw, text, newline, idx, nil +} + +func hasContinuation(text string) bool { + trimmed := strings.TrimRight(text, " \t") + if !strings.HasSuffix(trimmed, `\`) { + return false + } + count := 0 + for idx := len(trimmed) - 1; idx >= 0 && trimmed[idx] == '\\'; idx-- { + count++ + } + return count%2 == 1 +} + +func trimContinuation(text string) string { + trimmed := strings.TrimRight(text, " \t") + return trimmed[:len(trimmed)-1] +} + +func continuationColumn(text string) int { + trimmed := strings.TrimRight(text, " \t") + return len(trimmed) +} + +func parseRawLine(raw, text string, doc *Document, lineNo int) (parsedLine, error) { + trimmed := text + if doc.TrimSpace { + trimmed = strings.TrimSpace(text) + } + if trimmed == "" { + return parsedLine{kind: lineEmpty, raw: raw, line: lineNo}, nil + } + for _, head := range doc.CommentHeads { + if strings.HasPrefix(trimmed, head) { + return parsedLine{kind: lineComment, raw: raw, comment: strings.TrimSpace(strings.TrimPrefix(trimmed, head)), line: lineNo}, nil + } + } + if doc.SectionOpen != "" && doc.SectionClose != "" && strings.HasPrefix(trimmed, doc.SectionOpen) { + line, err := parseSectionLine(raw, trimmed, doc, lineNo) + if err != nil { + return parsedLine{}, err + } + return line, nil + } + delimiter, idx := findAssignDelimiter(trimmed, doc.assignDelimiters()) + if idx < 0 { + if doc.AllowNoValue { + key, comment, err := splitInlineComment(trimmed, doc, lineNo) + if err != nil { + return parsedLine{}, err + } + key = strings.TrimSpace(key) + if key == "" && doc.Strict { + return parsedLine{}, &ParseError{Line: lineNo, Column: 1, Message: "empty key"} + } + return parsedLine{kind: linePair, raw: raw, key: key, comment: comment, noValue: true, delimiter: doc.writeDelimiter(), line: lineNo}, nil + } + if doc.Strict { + return parsedLine{}, &ParseError{Line: lineNo, Column: 1, Message: "missing key/value delimiter"} + } + return parsedLine{kind: lineRaw, raw: raw, line: lineNo}, nil + } + key := strings.TrimSpace(trimmed[:idx]) + if key == "" && doc.Strict { + return parsedLine{}, &ParseError{Line: lineNo, Column: 1, Message: "empty key"} + } + right := strings.TrimSpace(trimmed[idx+len(delimiter):]) + value, comment, err := splitInlineComment(right, doc, lineNo) + if err != nil { + return parsedLine{}, err + } + value, err = parseValue(value, doc, lineNo) + if err != nil { + return parsedLine{}, err + } + return parsedLine{kind: linePair, raw: raw, key: key, value: value, comment: comment, delimiter: delimiter, line: lineNo}, nil +} + +func parseSectionLine(raw, trimmed string, doc *Document, lineNo int) (parsedLine, error) { + end := strings.Index(trimmed[len(doc.SectionOpen):], doc.SectionClose) + if end < 0 { + if doc.Strict { + return parsedLine{}, &ParseError{Line: lineNo, Column: len(trimmed), Message: "section header is missing closing marker"} + } + return parsedLine{kind: lineRaw, raw: raw, line: lineNo}, nil + } + end += len(doc.SectionOpen) + name := strings.TrimSpace(trimmed[len(doc.SectionOpen):end]) + if name == "" && doc.Strict { + return parsedLine{}, &ParseError{Line: lineNo, Column: len(doc.SectionOpen) + 1, Message: "empty section name"} + } + tail := strings.TrimSpace(trimmed[end+len(doc.SectionClose):]) + comment := "" + if tail != "" { + var err error + rest, inlineComment, err := splitInlineComment(tail, doc, lineNo) + if err != nil { + return parsedLine{}, err + } + if strings.TrimSpace(rest) != "" { + if doc.Strict { + return parsedLine{}, &ParseError{Line: lineNo, Column: end + len(doc.SectionClose) + 1, Message: "unexpected text after section header"} + } + return parsedLine{kind: lineRaw, raw: raw, line: lineNo}, nil + } + comment = inlineComment + } + return parsedLine{kind: lineSection, raw: raw, sectionName: name, comment: comment, line: lineNo}, nil +} + +func (d *Document) assignDelimiters() []string { + if d == nil { + return nil + } + if len(d.AssignDelimiters) > 0 { + out := make([]string, 0, len(d.AssignDelimiters)) + for _, delimiter := range d.AssignDelimiters { + if delimiter != "" { + out = append(out, delimiter) + } + } + sort.SliceStable(out, func(i, j int) bool { + return len(out[i]) > len(out[j]) + }) + return out + } + if d.Assign == "" { + return nil + } + return []string{d.Assign} +} + +func (d *Document) writeDelimiter() string { + if d == nil { + return "=" + } + if d.Assign != "" { + return d.Assign + } + if len(d.AssignDelimiters) > 0 { + for _, delimiter := range d.AssignDelimiters { + if delimiter != "" { + return delimiter + } + } + } + return "=" +} + +func findAssignDelimiter(text string, delimiters []string) (string, int) { + if len(delimiters) == 0 { + return "", -1 + } + var quote byte + escaped := false + for idx := 0; idx < len(text); idx++ { + ch := text[idx] + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' || ch == '\'' { + if quote == 0 { + quote = ch + continue + } + if quote == ch { + quote = 0 + continue + } + } + if quote != 0 { + continue + } + for _, delimiter := range delimiters { + if delimiter != "" && strings.HasPrefix(text[idx:], delimiter) { + return delimiter, idx + } + } + } + return "", -1 +} + +func splitInlineComment(text string, doc *Document, lineNo int) (string, string, error) { + if !doc.AllowInline { + return strings.TrimSpace(text), "", nil + } + var quote byte + escaped := false + for i := 0; i < len(text); i++ { + ch := text[i] + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' || ch == '\'' { + if quote == 0 { + quote = ch + continue + } + if quote == ch { + quote = 0 + continue + } + } + if quote != 0 { + continue + } + for _, head := range doc.CommentHeads { + if head == "" || !strings.HasPrefix(text[i:], head) { + continue + } + if doc.InlineCommentRequiresSpace && i > 0 && text[i-1] != ' ' && text[i-1] != '\t' { + continue + } + return strings.TrimSpace(text[:i]), strings.TrimSpace(text[i+len(head):]), nil + } + } + if quote != 0 && doc.Strict { + return "", "", &ParseError{Line: lineNo, Column: len(text), Message: "unterminated quoted value"} + } + return strings.TrimSpace(text), "", nil +} + +func parseValue(value string, doc *Document, lineNo int) (string, error) { + value = strings.TrimSpace(value) + if value == "" { + return "", nil + } + quote := value[0] + if quote != '"' && quote != '\'' { + return value, nil + } + escaped := false + end := -1 + for idx := 1; idx < len(value); idx++ { + ch := value[idx] + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + end = idx + break + } + } + if end < 0 { + if doc.Strict { + return "", &ParseError{Line: lineNo, Column: 1, Message: "unterminated quoted value"} + } + return value, nil + } + if strings.TrimSpace(value[end+1:]) != "" { + if doc.Strict { + return "", &ParseError{Line: lineNo, Column: end + 2, Message: "unexpected text after quoted value"} + } + return value, nil + } + return unescapeQuotedValue(value[1:end], quote), nil +} + +func unescapeQuotedValue(value string, quote byte) string { + var builder strings.Builder + builder.Grow(len(value)) + escaped := false + for idx := 0; idx < len(value); idx++ { + ch := value[idx] + if !escaped { + if ch == '\\' { + escaped = true + continue + } + builder.WriteByte(ch) + continue + } + switch ch { + case 'n': + builder.WriteByte('\n') + case 'r': + builder.WriteByte('\r') + case 't': + builder.WriteByte('\t') + case '\\': + builder.WriteByte('\\') + case '"', '\'': + if ch == quote { + builder.WriteByte(ch) + } else { + builder.WriteByte('\\') + builder.WriteByte(ch) + } + default: + builder.WriteByte('\\') + builder.WriteByte(ch) + } + escaped = false + } + if escaped { + builder.WriteByte('\\') + } + return builder.String() +} + +func (p parsedLine) toEntry() *Entry { + switch p.kind { + case lineComment: + return &Entry{Raw: p.raw, Comment: p.comment, Newline: p.newline, kind: p.kind} + case lineEmpty: + return &Entry{Raw: p.raw, Newline: p.newline, kind: p.kind} + case lineRaw: + return &Entry{Raw: p.raw, Newline: p.newline, kind: p.kind} + } + entry := &Entry{ + Key: p.key, + Comment: p.comment, + Raw: p.raw, + NoValue: p.noValue, + Delimiter: p.delimiter, + Newline: p.newline, + kind: p.kind, + } + if !p.noValue { + entry.Values = []string{p.value} + } + entry.rememberParsed() + return entry +} + +func (e *Entry) rememberParsed() { + if e == nil { + return + } + e.parsedKey = e.Key + e.parsedValues = append([]string(nil), e.Values...) + e.parsedComment = e.Comment + e.parsedNoValue = e.NoValue + e.parsedDelimiter = e.Delimiter +} + +func (e *Entry) parsedPairUnchanged() bool { + if e == nil || e.kind != linePair || e.Raw == "" { + return false + } + if e.Key != e.parsedKey || e.Comment != e.parsedComment || e.NoValue != e.parsedNoValue || e.Delimiter != e.parsedDelimiter { + return false + } + if len(e.Values) != len(e.parsedValues) { + return false + } + for idx := range e.Values { + if e.Values[idx] != e.parsedValues[idx] { + return false + } + } + return true +} + +func (d *Document) ensureSection(name string) *Section { + d.rebuildSectionIndexLocked() + key := normalize(name, d.CaseSensitive) + if sections := d.sectionIndex[key]; len(sections) > 0 { + return sections[0] + } + return d.appendSection(name, "", "", "\n") +} + +func (d *Document) rebuildSectionIndexLocked() { + if d.sectionIndex == nil { + d.sectionIndex = make(map[string][]*Section) + } + for key := range d.sectionIndex { + delete(d.sectionIndex, key) + } + for _, section := range d.sections { + if section == nil { + continue + } + key := normalize(section.Name, d.CaseSensitive) + d.sectionIndex[key] = append(d.sectionIndex[key], section) + } +} + +func (d *Document) appendSection(name, comment, raw, newline string) *Section { + section := &Section{ + Name: name, + HeaderComment: comment, + Raw: raw, + Newline: newline, + CaseSensitive: d.CaseSensitive, + entryIndex: make(map[string][]*Entry), + rawName: name, + parsedHeaderComment: comment, + } + if section.Newline == "" && raw == "" { + section.Newline = "\n" + } + key := normalize(name, d.CaseSensitive) + d.sections = append(d.sections, section) + d.sectionIndex[key] = append(d.sectionIndex[key], section) + return section +} + +func (s *Section) addParsed(line parsedLine) { + switch line.kind { + case linePair: + entry := line.toEntry() + s.addEntry(entry) + case lineComment, lineEmpty, lineRaw: + s.Entries = append(s.Entries, line.toEntry()) + } +} + +func (s *Section) addEntry(entry *Entry) { + if s.entryIndex == nil { + s.entryIndex = make(map[string][]*Entry) + } + key := normalize(entry.Key, s.CaseSensitive) + s.Entries = append(s.Entries, entry) + s.entryIndex[key] = append(s.entryIndex[key], entry) +} + +func (d *Document) Section(name string) *Section { + if d == nil { + return nil + } + d.mu.Lock() + defer d.mu.Unlock() + d.rebuildSectionIndexLocked() + sections := d.sectionIndex[normalize(name, d.CaseSensitive)] + if len(sections) == 0 { + return nil + } + return sections[0] +} + +func (d *Document) SectionsByName(name string) []*Section { + if d == nil { + return nil + } + d.mu.Lock() + defer d.mu.Unlock() + d.rebuildSectionIndexLocked() + sections := d.sectionIndex[normalize(name, d.CaseSensitive)] + return append([]*Section(nil), sections...) +} + +func (d *Document) Sections() []*Section { + if d == nil { + return nil + } + d.mu.RLock() + defer d.mu.RUnlock() + return append([]*Section(nil), d.sections...) +} + +func (s *Section) Entry(key string) *Entry { + if s == nil { + return nil + } + entries := s.entryIndex[normalize(key, s.CaseSensitive)] + if len(entries) == 0 { + return nil + } + return entries[0] +} + +func (s *Section) EntriesByKey(key string) []*Entry { + if s == nil { + return nil + } + entries := s.entryIndex[normalize(key, s.CaseSensitive)] + return append([]*Entry(nil), entries...) +} + +func (s *Section) Keys() []string { + if s == nil { + return nil + } + keys := make([]string, 0, len(s.entryIndex)) + for _, entries := range s.entryIndex { + if len(entries) == 0 { + continue + } + keys = append(keys, entries[0].Key) + } + sort.Strings(keys) + return keys +} + +func (s *Section) Get(key string) string { + if entry := s.Entry(key); entry != nil && len(entry.Values) > 0 { + return entry.Values[0] + } + return "" +} + +func (s *Section) GetAll(key string) []string { + if s == nil { + return nil + } + entries := s.EntriesByKey(key) + if len(entries) == 0 { + return nil + } + values := make([]string, 0) + for _, entry := range entries { + values = append(values, entry.Values...) + } + return values +} + +func (s *Section) Int(key string) int { + v, _ := strconv.Atoi(s.Get(key)) + return v +} + +func (s *Section) Int64(key string) int64 { + v, _ := strconv.ParseInt(s.Get(key), 10, 64) + return v +} + +func (s *Section) Int32(key string) int32 { + v, _ := strconv.ParseInt(s.Get(key), 10, 32) + return int32(v) +} + +func (s *Section) Float64(key string) float64 { + v, _ := strconv.ParseFloat(s.Get(key), 64) + return v +} + +func (s *Section) Float32(key string) float32 { + v, _ := strconv.ParseFloat(s.Get(key), 32) + return float32(v) +} + +func (s *Section) Bool(key string) bool { + v, _ := strconv.ParseBool(s.Get(key)) + return v +} + +func (s *Section) SetBool(key string, value bool, comment string) error { + return s.Set(key, strconv.FormatBool(value), comment) +} + +func (s *Section) SetFloat64(key string, prec int, value float64, comment string) error { + return s.Set(key, strconv.FormatFloat(value, 'f', prec, 64), comment) +} + +func (s *Section) SetFloat32(key string, prec int, value float32, comment string) error { + return s.Set(key, strconv.FormatFloat(float64(value), 'f', prec, 32), comment) +} + +func (s *Section) SetUint64(key string, value uint64, comment string) error { + return s.Set(key, strconv.FormatUint(value, 10), comment) +} + +func (s *Section) SetInt64(key string, value int64, comment string) error { + return s.Set(key, strconv.FormatInt(value, 10), comment) +} + +func (s *Section) SetInt32(key string, value int32, comment string) error { + return s.Set(key, strconv.FormatInt(int64(value), 10), comment) +} + +func (s *Section) SetInt(key string, value int, comment string) error { + return s.Set(key, strconv.Itoa(value), comment) +} + +type countingWriter struct { + w io.Writer + n int64 +} + +func (w *countingWriter) Write(data []byte) (int, error) { + n, err := w.w.Write(data) + w.n += int64(n) + return n, err +} + +func (d *Document) WriteTo(w io.Writer) (int64, error) { + if d == nil { + return 0, ErrDocumentClosed + } + d.mu.RLock() + defer d.mu.RUnlock() + counting := &countingWriter{w: w} + for _, section := range d.sections { + if section == nil { + continue + } + if section.Raw != "" && section.Name == section.rawName && section.HeaderComment == section.parsedHeaderComment { + if err := writeLine(counting, section.Raw, section.Newline, true); err != nil { + return counting.n, err + } + } else if section.Name != "" && d.SectionOpen != "" { + if err := writeLine(counting, d.formatSectionHeader(section), section.Newline, false); err != nil { + return counting.n, err + } + } + for _, entry := range section.Entries { + if err := d.writeEntry(counting, entry); err != nil { + return counting.n, err + } + } + } + return counting.n, nil +} + +func writeLine(w io.Writer, text, newline string, preserveNoNewline bool) error { + if text != "" { + if _, err := io.WriteString(w, text); err != nil { + return err + } + } + if newline != "" { + _, err := io.WriteString(w, newline) + return err + } + if preserveNoNewline { + return nil + } + _, err := io.WriteString(w, "\n") + return err +} + +func (d *Document) writeEntry(w io.Writer, entry *Entry) error { + if entry == nil { + return nil + } + if entry.kind != linePair && entry.Raw != "" { + return writeLine(w, entry.Raw, entry.Newline, true) + } + if entry.parsedPairUnchanged() { + return writeLine(w, entry.Raw, entry.Newline, true) + } + newline := entry.Newline + if newline == "" { + newline = "\n" + } + delimiter := entry.Delimiter + if delimiter == "" { + delimiter = d.Assign + } + if delimiter == "" { + delimiter = "=" + } + if entry.NoValue || len(entry.Values) == 0 { + text := entry.Key + if entry.Comment != "" && d.AllowInline && len(d.CommentHeads) > 0 { + text += " " + d.CommentHeads[0] + entry.Comment + return writeLine(w, text, newline, false) + } + if err := writeLine(w, text, newline, false); err != nil { + return err + } + if entry.Comment != "" && len(d.CommentHeads) > 0 { + return writeLine(w, d.CommentHeads[0]+entry.Comment, newline, false) + } + return nil + } + value := d.formatValue(entry.Values[0]) + text := entry.Key + delimiter + value + if entry.Comment != "" && d.AllowInline && len(d.CommentHeads) > 0 { + text += " " + d.CommentHeads[0] + entry.Comment + } + if err := writeLine(w, text, newline, false); err != nil { + return err + } + if d.AllowMulti && len(entry.Values) > 1 { + for _, extra := range entry.Values[1:] { + extraValue := d.formatValue(extra) + if err := writeLine(w, entry.Key+delimiter+extraValue, newline, false); err != nil { + return err + } + } + } + return nil +} + +func (d *Document) formatSectionHeader(section *Section) string { + text := d.SectionOpen + section.Name + d.SectionClose + if section.HeaderComment != "" && d.AllowInline && len(d.CommentHeads) > 0 { + text += " " + d.CommentHeads[0] + section.HeaderComment + } + return text +} + +func (d *Document) formatValue(value string) string { + if !d.valueNeedsQuotes(value) { + return value + } + return quoteValue(value) +} + +func (d *Document) valueNeedsQuotes(value string) bool { + if value == "" { + return false + } + if strings.TrimSpace(value) != value || strings.ContainsAny(value, "\r\n\t") { + return true + } + if value[0] == '"' || value[0] == '\'' { + return true + } + if d != nil && d.AllowInline { + for _, head := range d.CommentHeads { + if head != "" && strings.Contains(value, head) { + return true + } + } + } + return false +} + +func quoteValue(value string) string { + var builder strings.Builder + builder.Grow(len(value) + 2) + builder.WriteByte('"') + for idx := 0; idx < len(value); idx++ { + switch value[idx] { + case '\\': + builder.WriteString(`\\`) + case '"': + builder.WriteString(`\"`) + case '\n': + builder.WriteString(`\n`) + case '\r': + builder.WriteString(`\r`) + case '\t': + builder.WriteString(`\t`) + default: + builder.WriteByte(value[idx]) + } + } + builder.WriteByte('"') + return builder.String() +} + +func (d *Document) Bytes() []byte { + var buf bytes.Buffer + _, _ = d.WriteTo(&buf) + return buf.Bytes() +} + +func (d *Document) Save(path string) error { + if d == nil { + return ErrDocumentClosed + } + return os.WriteFile(path, d.Bytes(), 0o644) +} + +func (d *Document) SaveAtomic(path string) error { + if d == nil { + return ErrDocumentClosed + } + return writeFileAtomic(path, d.Bytes(), 0o644) +} + +func writeFileAtomic(path string, data []byte, defaultPerm os.FileMode) error { + perm := defaultPerm + if info, err := os.Stat(path); err == nil { + perm = info.Mode().Perm() + } else if !os.IsNotExist(err) { + return err + } + dir := filepath.Dir(path) + base := filepath.Base(path) + tmp, err := os.CreateTemp(dir, "."+base+".tmp-*") + if err != nil { + return err + } + tmpPath := tmp.Name() + keepTmp := false + defer func() { + if !keepTmp { + _ = os.Remove(tmpPath) + } + }() + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Chmod(perm); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Sync(); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Close(); err != nil { + return err + } + if err := os.Rename(tmpPath, path); err != nil { + return err + } + keepTmp = true + syncParentDir(dir) + return nil +} + +func syncParentDir(dir string) { + f, err := os.Open(dir) + if err != nil { + return + } + _ = f.Sync() + _ = f.Close() +} + +func (s *Section) Exist(key string) bool { + return s.Entry(key) != nil +} + +func (s *Section) Comment(key string) string { + if entry := s.Entry(key); entry != nil { + return entry.Comment + } + return "" +} + +func (s *Section) SetComment(key, comment string) error { + entry := s.Entry(key) + if entry == nil { + return ErrKeyNotFound + } + entry.Comment = comment + return nil +} + +func (s *Section) Set(key, value, comment string) error { + return s.SetAll(key, []string{value}, comment) +} + +func (s *Section) SetAll(key string, values []string, comment string) error { + if s == nil { + return ErrSectionNotFound + } + if s.entryIndex == nil { + s.entryIndex = make(map[string][]*Entry) + } + normalized := normalize(key, s.CaseSensitive) + entries := s.entryIndex[normalized] + if len(entries) == 0 { + entry := &Entry{Key: key, Values: append([]string(nil), values...), Comment: comment, Newline: s.Newline} + if len(values) == 0 { + entry.NoValue = true + } + s.addEntry(entry) + return nil + } + first := entries[0] + first.Values = append([]string(nil), values...) + first.Comment = comment + first.NoValue = len(values) == 0 + if len(entries) > 1 { + s.entryIndex[normalized] = []*Entry{first} + filtered := s.Entries[:0] + for _, entry := range s.Entries { + if normalize(entry.Key, s.CaseSensitive) != normalized { + filtered = append(filtered, entry) + continue + } + if entry == first { + filtered = append(filtered, entry) + } + } + s.Entries = filtered + } + return nil +} + +func (s *Section) AddValue(key, value, comment string) error { + if s == nil { + return ErrSectionNotFound + } + if s.entryIndex == nil { + s.entryIndex = make(map[string][]*Entry) + } + normalized := normalize(key, s.CaseSensitive) + entry := &Entry{Key: key, Values: []string{value}, Comment: comment, Newline: s.Newline} + s.Entries = append(s.Entries, entry) + s.entryIndex[normalized] = append(s.entryIndex[normalized], entry) + return nil +} + +func (s *Section) Delete(key string) error { + if s == nil { + return ErrSectionNotFound + } + normalized := normalize(key, s.CaseSensitive) + if len(s.entryIndex[normalized]) == 0 { + return ErrKeyNotFound + } + delete(s.entryIndex, normalized) + filtered := s.Entries[:0] + for _, entry := range s.Entries { + if normalize(entry.Key, s.CaseSensitive) == normalized { + continue + } + filtered = append(filtered, entry) + } + s.Entries = filtered + return nil +} + +func (s *Section) DeleteValue(key, value string) error { + normalized := normalize(key, s.CaseSensitive) + entries := s.entryIndex[normalized] + if len(entries) == 0 { + return ErrKeyNotFound + } + kept := make([]*Entry, 0, len(entries)) + for _, entry := range entries { + filtered := entry.Values[:0] + for _, item := range entry.Values { + if item == value { + continue + } + filtered = append(filtered, item) + } + if len(filtered) == 0 { + continue + } + entry.Values = filtered + kept = append(kept, entry) + } + if len(kept) == 0 { + delete(s.entryIndex, normalized) + } else { + s.entryIndex[normalized] = kept + } + filteredEntries := s.Entries[:0] + for _, entry := range s.Entries { + if normalize(entry.Key, s.CaseSensitive) != normalized { + filteredEntries = append(filteredEntries, entry) + continue + } + for _, keptEntry := range kept { + if entry == keptEntry { + filteredEntries = append(filteredEntries, entry) + break + } + } + } + s.Entries = filteredEntries + return nil +} diff --git a/sysconf/example_migration_test.go b/sysconf/example_migration_test.go new file mode 100644 index 0000000..c85d457 --- /dev/null +++ b/sysconf/example_migration_test.go @@ -0,0 +1,27 @@ +package sysconf + +import ( + "fmt" + "strings" +) + +func ExampleNewIni_migration() { + ini := NewIni() + _ = ini.Parse([]byte("[app]\nport=8080\nfeature=alpha\nfeature=beta\n")) + + app := ini.Section("app") + _ = app.SetInt("port", 9090, "") + _ = app.SetAll("feature", []string{"stable", "audit"}, "") + + fmt.Println(ini.Get("app", "port")) + fmt.Println(ini.GetAll("app", "feature")) + fmt.Println(strings.TrimSpace(string(ini.Build()))) + + // Output: + // 9090 + // [stable audit] + // [app] + // port=9090 + // feature=stable + // feature=audit +} diff --git a/sysconf/ini.go b/sysconf/ini.go new file mode 100644 index 0000000..024bdc5 --- /dev/null +++ b/sysconf/ini.go @@ -0,0 +1,445 @@ +package sysconf + +import ( + "errors" + "fmt" + "os" + "reflect" + "sort" +) + +type Ini struct { + *Document +} + +type IniProfile func(*Ini) + +func NewIni() *Ini { + return &Ini{Document: NewDocument()} +} + +func NewIniWithProfiles(profiles ...IniProfile) *Ini { + ini := NewIni() + for _, profile := range profiles { + if profile != nil { + profile(ini) + } + } + return ini +} + +func DefaultINIProfile() IniProfile { + return func(ini *Ini) { + if ini == nil { + return + } + ini.Document = NewDocument() + } +} + +func StrictINIProfile() IniProfile { + return func(ini *Ini) { + if ini == nil { + return + } + if ini.Document == nil { + ini.Document = NewDocument() + } + ini.Strict = true + ini.AllowNoValue = false + } +} + +func LinuxConfProfile(equal string) IniProfile { + return func(ini *Ini) { + if ini == nil { + return + } + if ini.Document == nil { + ini.Document = NewDocument() + } + ini.SectionOpen = "" + ini.SectionClose = "" + ini.CommentHeads = []string{"#"} + if equal != "" { + ini.Assign = equal + } + ini.AssignDelimiters = []string{ini.Assign} + } +} + +func (i *Ini) ApplyProfile(profile IniProfile) *Ini { + if profile != nil { + profile(i) + } + return i +} + +func NewSysConf(equal string) *Ini { + ini := NewIni() + if equal != "" { + ini.Assign = equal + } + ini.AssignDelimiters = []string{ini.Assign} + return ini +} + +func NewLinuxConf(equal string) *Ini { + return NewIniWithProfiles(LinuxConfProfile(equal)) +} + +func (i *Ini) Parse(data []byte) error { + if i == nil || i.Document == nil { + return ErrDocumentClosed + } + return i.Document.Parse(data) +} + +func (i *Ini) ParseFromFile(path string) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + return i.Parse(data) +} + +func (i *Ini) Build() []byte { + if i == nil || i.Document == nil { + return nil + } + return i.Document.Bytes() +} + +func (i *Ini) Save(path string) error { + if i == nil || i.Document == nil { + return ErrDocumentClosed + } + return i.Document.Save(path) +} + +func (i *Ini) SaveAtomic(path string) error { + if i == nil || i.Document == nil { + return ErrDocumentClosed + } + return i.Document.SaveAtomic(path) +} + +func (i *Ini) Section(name string) *Section { + if i == nil || i.Document == nil { + return nil + } + return i.Document.Section(name) +} + +func (i *Ini) Sections(name string) []*Section { + if i == nil || i.Document == nil { + return nil + } + return i.Document.SectionsByName(name) +} + +func (i *Ini) AddSection(name string) *Section { + if i == nil || i.Document == nil { + return nil + } + i.Document.mu.Lock() + defer i.Document.mu.Unlock() + return i.Document.appendSection(name, "", "", "\n") +} + +func (i *Ini) DeleteSection(name string) bool { + if i == nil || i.Document == nil { + return false + } + i.Document.mu.Lock() + defer i.Document.mu.Unlock() + i.Document.rebuildSectionIndexLocked() + normalized := normalize(name, i.CaseSensitive) + sections := i.Document.sectionIndex[normalized] + if len(sections) == 0 { + return false + } + delete(i.Document.sectionIndex, normalized) + filtered := i.Document.sections[:0] + for _, section := range i.Document.sections { + if normalize(section.Name, i.CaseSensitive) == normalized { + continue + } + filtered = append(filtered, section) + } + i.Document.sections = filtered + return true +} + +func (i *Ini) Get(section, key string) string { + for _, s := range i.Sections(section) { + if s != nil && s.Exist(key) { + return s.Get(key) + } + } + return "" +} + +func (i *Ini) GetAll(section, key string) []string { + sections := i.Sections(section) + if len(sections) == 0 { + return nil + } + values := make([]string, 0) + for _, s := range sections { + if s == nil { + continue + } + values = append(values, s.GetAll(key)...) + } + if len(values) == 0 { + return nil + } + return values +} + +func (i *Ini) Has(section, key string) bool { + for _, s := range i.Sections(section) { + if s != nil && s.Exist(key) { + return true + } + } + return false +} + +func (i *Ini) Set(section, key, value string) { + if i == nil || i.Document == nil { + return + } + i.Document.mu.Lock() + s := i.Document.ensureSection(section) + i.Document.mu.Unlock() + if s != nil { + _ = s.Set(key, value, "") + } +} + +func (i *Ini) AddValue(section, key, value string) { + if i == nil || i.Document == nil { + return + } + i.Document.mu.Lock() + s := i.Document.ensureSection(section) + i.Document.mu.Unlock() + if s != nil { + _ = s.AddValue(key, value, "") + } +} + +func (i *Ini) Delete(section, key string) bool { + if s := i.Section(section); s != nil { + return s.Delete(key) == nil + } + return false +} + +func (i *Ini) SectionsMap() map[string][]*Section { + if i == nil || i.Document == nil { + return nil + } + i.Document.mu.Lock() + defer i.Document.mu.Unlock() + i.Document.rebuildSectionIndexLocked() + out := make(map[string][]*Section, len(i.Document.sectionIndex)) + for name, sections := range i.Document.sectionIndex { + out[name] = append([]*Section(nil), sections...) + } + return out +} + +func (i *Ini) Unmarshal(dst interface{}) error { + return bindINI(i, dst) +} + +func (i *Ini) Marshal(src interface{}) ([]byte, error) { + tmp := newIniLike(i) + if err := marshalINI(tmp, src); err != nil { + return nil, err + } + return tmp.Build(), nil +} + +func bindINI(i *Ini, dst interface{}) error { + if dst == nil { + return errors.New("destination is nil") + } + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Ptr || v.IsNil() { + return errors.New("destination must be a non-nil pointer") + } + v = v.Elem() + if v.Kind() != reflect.Struct { + return errors.New("destination must point to a struct") + } + return bindStruct(i, v, "") +} + +func bindStruct(i *Ini, v reflect.Value, inheritedSection string) error { + t := v.Type() + for idx := 0; idx < t.NumField(); idx++ { + field := t.Field(idx) + value := v.Field(idx) + if !value.CanSet() { + continue + } + section := field.Tag.Get("seg") + key := field.Tag.Get("key") + if section == "" { + section = inheritedSection + } + if key == "-" { + continue + } + if isNestedConfigStruct(value, key) { + nested := value + for nested.Kind() == reflect.Ptr { + if nested.IsNil() { + nested.Set(reflect.New(nested.Type().Elem())) + } + nested = nested.Elem() + } + if err := bindStruct(i, nested, section); err != nil { + return err + } + continue + } + if key == "" { + continue + } + items := configValuesFromSections(i.Sections(section), key) + if len(items) == 0 { + continue + } + if err := setINIField(value, items); err != nil { + return err + } + } + return nil +} + +func setINIField(value reflect.Value, items []configValue) error { + return setConfigValueItems(value, items) +} + +func marshalINI(dst *Ini, src interface{}) error { + v := reflect.ValueOf(src) + if !v.IsValid() { + return errors.New("nil source") + } + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return errors.New("nil source") + } + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return errors.New("source must be struct") + } + return marshalStruct(dst, v, "") +} + +func marshalStruct(dst *Ini, v reflect.Value, inheritedSection string) error { + t := v.Type() + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + fv := v.Field(i) + if !fv.CanInterface() { + continue + } + section := field.Tag.Get("seg") + if section == "" { + section = inheritedSection + } + key := field.Tag.Get("key") + comment := field.Tag.Get("comment") + if key == "-" { + continue + } + if nested, ok := nestedConfigValueForWrite(fv, key); ok { + if nested.IsValid() { + if err := marshalStruct(dst, nested, section); err != nil { + return err + } + } + continue + } + if key == "" { + continue + } + if err := setINIValue(dst, section, key, fv); err != nil { + return err + } + if comment != "" { + sec := dst.Section(section) + if sec == nil { + sec = dst.AddSection(section) + } + if sec != nil { + _ = sec.SetComment(key, comment) + } + } + } + return nil +} + +func marshalSection(dst *Ini, section string, value reflect.Value) error { + return marshalStruct(dst, value, section) +} + +func setINIValue(dst *Ini, section, key string, value reflect.Value) error { + for value.Kind() == reflect.Ptr { + if value.IsNil() { + return nil + } + value = value.Elem() + } + switch value.Kind() { + case reflect.Slice, reflect.Array: + if value.Type().Elem().Kind() != reflect.String { + dst.Set(section, key, fmt.Sprint(value.Interface())) + return nil + } + sec := dst.Section(section) + if sec == nil { + sec = dst.AddSection(section) + } + if sec == nil { + return ErrDocumentClosed + } + values := make([]string, 0, value.Len()) + for idx := 0; idx < value.Len(); idx++ { + values = append(values, value.Index(idx).String()) + } + return sec.SetAll(key, values, "") + case reflect.Map: + if value.Type().Key().Kind() != reflect.String || value.Type().Elem().Kind() != reflect.String { + dst.Set(section, key, fmt.Sprint(value.Interface())) + return nil + } + keys := make([]string, 0, value.Len()) + for _, mapKey := range value.MapKeys() { + keys = append(keys, mapKey.String()) + } + sort.Strings(keys) + values := make([]string, 0, len(keys)) + for _, mapKey := range keys { + values = append(values, mapKey+"="+value.MapIndex(reflect.ValueOf(mapKey)).String()) + } + sec := dst.Section(section) + if sec == nil { + sec = dst.AddSection(section) + } + if sec == nil { + return ErrDocumentClosed + } + return sec.SetAll(key, values, "") + default: + dst.Set(section, key, fmt.Sprint(value.Interface())) + return nil + } +} diff --git a/sysconf/sysconf.go b/sysconf/sysconf.go deleted file mode 100644 index ffeb2c1..0000000 --- a/sysconf/sysconf.go +++ /dev/null @@ -1,855 +0,0 @@ -package sysconf - -import ( - "bytes" - "errors" - "fmt" - "io/ioutil" - "reflect" - "strconv" - "strings" - "sync" -) - -type SysConf struct { - Data []*SysSegment - segmap map[string]int64 - segId int64 - HaveSegMent bool //是否有节这个概念 - SegStart string - SegEnd string - CommentFlag []string //评论标识符,如# - EqualFlag string //赋值标识符,如= - ValueFlag string //值标识符,如" - EscapeFlag string //转义字符 - CommentCR bool //评论是否能与value同一行,true不行,false可以 - SpaceStr string //美化符号 - lock sync.RWMutex -} - -type SysSegment struct { - Name string - //nodeMap - Comment string - NodeData []*SysNode - nodeId int64 - nodeMap map[string]int64 - lock sync.RWMutex -} - -type SysNode struct { - Key string - Value []string - Comment string - NoValue bool - lock sync.RWMutex -} - -func NewIni() *SysConf { - ini := NewSysConf("=") - ini.CommentCR = true - ini.CommentFlag = []string{"#", ";"} - ini.HaveSegMent = true - ini.SegStart = "[" - ini.SegEnd = "]" - ini.SpaceStr = " " - ini.EscapeFlag = "\\" - return ini -} - -func NewSysConf(EqualFlag string) *SysConf { - syscnf := new(SysConf) - syscnf.EqualFlag = EqualFlag - return syscnf -} - -// NewLinuxConf sysctl.conf like file -func NewLinuxConf(EqualFlag string) *SysConf { - syscnf := new(SysConf) - syscnf.EqualFlag = EqualFlag - syscnf.HaveSegMent = false - syscnf.CommentCR = true - syscnf.CommentFlag = []string{"#"} - return syscnf -} - -func (syscfg *SysConf) ParseFromFile(filepath string) error { - data, err := ioutil.ReadFile(filepath) - if err != nil { - return err - } - return syscfg.Parse(data) -} - -// Parse 生成INI文件结构 -func (syscfg *SysConf) Parse(data []byte) error { - syscfg.lock.Lock() - defer syscfg.lock.Unlock() - if syscfg.HaveSegMent && (syscfg.SegStart == "" || syscfg.SegEnd == "") { - return errors.New("SegMent Start or End Flag Not Allowed!") - } - if !syscfg.CommentCR { - return syscfg.parseNOCRComment(data) - } - return syscfg.parseCRComment(data) -} - -func (syscfg *SysConf) parseNOCRComment(data []byte) error { //允许comment在同一行 - data = bytes.TrimSpace(data) - dataLists := bytes.Split(data, []byte("\n")) - seg := new(SysSegment) - seg.nodeMap = make(map[string]int64) - syscfg.segmap = make(map[string]int64) - if syscfg.HaveSegMent { - seg.Name = "unnamed" - } - syscfg.segId = 0 - var node *SysNode - for _, v1 := range dataLists { - var ( - isSegStart bool = false - isEscape bool = false - isEqual bool = false - isComment bool = false - tsuMo string = "" - ) - cowStr := strings.TrimSpace(string(v1)) - for i := 0; i < len(cowStr); i++ { - runeStr := cowStr[i : i+1] //当前字符,rune扫描 - if runeStr == syscfg.EscapeFlag && (!isEscape) { - isEscape = true - continue - } - if runeStr == syscfg.SegStart && (!isEscape) { - - isSegStart = true - continue - } - if runeStr == syscfg.SegEnd && (!isEscape) { - isSegStart = false - //New segment start from here - if seg.Name == "unnamed" && len(seg.NodeData) == 0 { - seg.Name = tsuMo - tsuMo = "" - continue - } - syscfg.segmap[seg.Name] = syscfg.segId - syscfg.segId++ - syscfg.Data = append(syscfg.Data, seg) - seg = new(SysSegment) - seg.nodeMap = make(map[string]int64) - seg.Name = tsuMo - tsuMo = "" - continue - } - if isSegStart { - tsuMo += runeStr - if isEscape { - isEscape = false - } - continue - } - if syscfg.EqualFlag == runeStr && (!isEscape) && (!isEqual) { - key := strings.TrimSpace(tsuMo) - if val, ok := seg.nodeMap[key]; ok { - node = seg.NodeData[val] - } else { - node = new(SysNode) - node.Key = key - seg.nodeMap[node.Key] = seg.nodeId - seg.nodeId++ - seg.NodeData = append(seg.NodeData, node) - } - tsuMo = "" - isEqual = true - if syscfg.ValueFlag != "" { - nokoriStr := strings.TrimSpace(cowStr[i+1:]) - isFound := false - isValue := false - for k4, v4 := range nokoriStr { - if string([]rune{v4}) == syscfg.ValueFlag { - isValue = !isValue - } - if SliceIn(syscfg.CommentFlag, string([]rune{v4})) && !isValue { - val := nokoriStr[:k4] - isFound = true - startFinder := strings.Index(val, syscfg.ValueFlag) - endFinder := strings.LastIndex(val, syscfg.ValueFlag) - if !((startFinder == -1 || endFinder == -1) || (endFinder-startFinder <= 0)) { - node.Value = append(node.Value, strings.TrimSpace(val[startFinder+1:endFinder])) - } - node.Comment = nokoriStr[k4+1:] + "\n" - } - } - if !isFound { - startFinder := strings.Index(nokoriStr, syscfg.ValueFlag) - endFinder := strings.LastIndex(nokoriStr, syscfg.ValueFlag) - if (startFinder == -1 || endFinder == -1) || (endFinder-startFinder <= 0) { - break - } - node.Value = append(node.Value, strings.TrimSpace(nokoriStr[startFinder+1:endFinder])) - } - break - } - continue - } - if SliceIn(syscfg.CommentFlag, runeStr) && (!isEscape) { - isComment = true - if seg.nodeId == 0 { - seg.Comment += strings.TrimSpace(cowStr[i+1:]) + "\n" - break - } - if tsuMo != "" { - node.Value = append(node.Value, strings.TrimSpace(tsuMo)) - tsuMo = "" - } - node.Comment += strings.TrimSpace(cowStr[i+1:]) + "\n" - break - } - isEscape = false - tsuMo += runeStr - } - if isEqual && tsuMo != "" { - node.Value = append(node.Value, strings.TrimSpace(tsuMo)) - } - if !isEqual && tsuMo != "" && !isComment { - node = new(SysNode) - node.Key = tsuMo - seg.nodeMap[node.Key] = seg.nodeId - seg.nodeId++ - seg.NodeData = append(seg.NodeData, node) - node.NoValue = true - } - } - if seg != nil { - syscfg.segmap[seg.Name] = syscfg.segId - syscfg.segId++ - syscfg.Data = append(syscfg.Data, seg) - } - return nil -} - -func (syscfg *SysConf) parseCRComment(data []byte) error { //不允许comment在同一行 - data = bytes.TrimSpace(data) - dataLists := bytes.Split(data, []byte("\n")) - seg := new(SysSegment) - seg.nodeMap = make(map[string]int64) - syscfg.segmap = make(map[string]int64) - if syscfg.HaveSegMent { - seg.Name = "unnamed" - } - syscfg.segId = 0 - var node *SysNode - for _, v1 := range dataLists { - var ( - isSegStart bool = false - isEscape bool = false - isEqual bool = false - isComment bool = false - tsuMo string = "" - ) - cowStr := strings.TrimSpace(string(v1)) - for i := 0; i < len(cowStr); i++ { - runeStr := cowStr[i : i+1] //当前字符,rune扫描 - if runeStr == syscfg.EscapeFlag && (!isEscape) { - isEscape = true - continue - } - if runeStr == syscfg.SegStart && (!isEscape) { - - isSegStart = true - continue - } - if runeStr == syscfg.SegEnd && (!isEscape) { - isSegStart = false - //New segment start from here - if seg.Name == "unnamed" && len(seg.NodeData) == 0 { - seg.Name = tsuMo - tsuMo = "" - break - } - syscfg.segmap[seg.Name] = syscfg.segId - syscfg.segId++ - syscfg.Data = append(syscfg.Data, seg) - seg = new(SysSegment) - seg.nodeMap = make(map[string]int64) - seg.Name = tsuMo - tsuMo = "" - break - } - if isSegStart { - tsuMo += runeStr - if isEscape { - isEscape = false - } - continue - } - if syscfg.EqualFlag == runeStr && (!isEscape) && (!isEqual) { - key := strings.TrimSpace(tsuMo) - if val, ok := seg.nodeMap[key]; ok { - node = seg.NodeData[val] - } else { - node = new(SysNode) - node.Key = key - seg.nodeMap[node.Key] = seg.nodeId - seg.nodeId++ - seg.NodeData = append(seg.NodeData, node) - } - tsuMo = "" - isEqual = true - if syscfg.ValueFlag == "" { - node.Value = append(node.Value, TrimEscape(strings.TrimSpace(cowStr[i+1:]), syscfg.EscapeFlag)) - } else { - nokoriStr := strings.TrimSpace(cowStr[i+1:]) - startFinder := strings.Index(nokoriStr, syscfg.ValueFlag) - endFinder := strings.LastIndex(nokoriStr, syscfg.ValueFlag) - if (startFinder == -1 || endFinder == -1) || (endFinder-startFinder <= 0) { - break - } - node.Value = append(node.Value, strings.TrimSpace(nokoriStr[startFinder+1:endFinder])) - } - break - } - if SliceIn(syscfg.CommentFlag, runeStr) && (!isEscape) { - isComment = true - tsuMo = "" - if seg.nodeId == 0 { - seg.Comment += strings.TrimSpace(cowStr[i+1:]) + "\n" - break - } - node.Comment += strings.TrimSpace(cowStr[i+1:]) + "\n" - break - } - isEscape = false - tsuMo += runeStr - } - if !isEqual && tsuMo != "" && !isComment { - node = new(SysNode) - node.Key = tsuMo - seg.nodeMap[node.Key] = seg.nodeId - seg.nodeId++ - seg.NodeData = append(seg.NodeData, node) - node.NoValue = true - } - } - if seg != nil { - syscfg.segmap[seg.Name] = syscfg.segId - syscfg.segId++ - syscfg.Data = append(syscfg.Data, seg) - } - return nil -} - -func (syscfg *SysConf) Build() []byte { - syscfg.lock.Lock() - defer syscfg.lock.Unlock() - var outPut string - for _, v := range syscfg.Data { - if v == nil { - continue - } - if syscfg.HaveSegMent { - outPut += syscfg.SegStart + v.Name + syscfg.SegEnd + "\n" - } - if v.Comment != "" { - v.Comment = v.Comment[:len(v.Comment)-1] - comment := strings.Split(v.Comment, "\n") - for _, vc := range comment { - if vc != "" { - outPut += syscfg.CommentFlag[0] + vc + "\n" - } else { - outPut += "\n" - } - } - } - for _, v2 := range v.NodeData { - if v2 == nil { - continue - } - if v2.NoValue { - outPut += v2.Key + "\n" - } else { - for _, v3 := range v2.Value { - if syscfg.ValueFlag != "" { - outPut += v2.Key + syscfg.SpaceStr + syscfg.EqualFlag + syscfg.SpaceStr + syscfg.ValueFlag + v3 + syscfg.ValueFlag + "\n" - } else { - outPut += v2.Key + syscfg.SpaceStr + syscfg.EqualFlag + syscfg.SpaceStr + syscfg.addEscape(v3) + "\n" - } - } - if len(v2.Value) == 0 { - outPut += v2.Key + syscfg.SpaceStr + syscfg.EqualFlag + "\n" - } - if v2.Comment != "" { - v2.Comment = v2.Comment[:len(v2.Comment)-1] - comment := strings.Split(v2.Comment, "\n") - for _, vc := range comment { - if vc != "" { - outPut += syscfg.CommentFlag[0] + vc + "\n" - } else { - outPut += "\n" - } - } - } - } - } - } - return []byte(outPut) -} - -func (syscfg *SysConf) addEscape(str string) string { - str = strings.ReplaceAll(str, syscfg.EscapeFlag, syscfg.EscapeFlag+syscfg.EscapeFlag) - str = strings.ReplaceAll(str, syscfg.EqualFlag, syscfg.EscapeFlag+syscfg.EqualFlag) - str = strings.ReplaceAll(str, syscfg.SegStart, syscfg.EscapeFlag+syscfg.SegStart) - str = strings.ReplaceAll(str, syscfg.SegEnd, syscfg.EscapeFlag+syscfg.SegEnd) - for _, v := range syscfg.CommentFlag { - str = strings.ReplaceAll(str, v, syscfg.EscapeFlag+v) - } - return str -} - -func (syscfg *SysConf) Reverse() { - syscfg.lock.Lock() - defer syscfg.lock.Unlock() - for _, v := range syscfg.Data { - if v == nil { - continue - } - var ( - NodeData []*SysNode - nodeId int64 - nodeMap map[string]int64 - ) - nodeMap = make(map[string]int64) - for _, v2 := range v.NodeData { - if v2 == nil { - continue - } - for _, v3 := range v2.Value { - var node *SysNode - if val, ok := nodeMap[v3]; ok { - node = NodeData[val] - } else { - node = new(SysNode) - node.Key = v3 - NodeData = append(NodeData, node) - nodeMap[v3] = nodeId - nodeId++ - } - node.Value = append(node.Value, strings.TrimSpace(v2.Key)) - node.Comment += v2.Comment - } - v.NodeData = NodeData - v.nodeId = nodeId - v.nodeMap = nodeMap - } - } -} - -func TrimEscape(text, escape string) string { - var isEscape bool = false - var outPut []rune - if escape == "" { - return text - } - text = strings.TrimSpace(text) - for _, v := range text { - if v == []rune(escape)[0] && !isEscape { - isEscape = true - continue - } - outPut = append(outPut, v) - } - return string(outPut) -} - -func SliceIn(slice interface{}, data interface{}) bool { - typed := reflect.ValueOf(slice) - if typed.Kind() == reflect.Slice || typed.Kind() == reflect.Array { - for i := 0; i < typed.Len(); i++ { - if typed.Index(i).Interface() == data { - return true - } - } - } - return false -} - -// Unmarshal 输出结果到结构体中 -func (cfg *SysConf) Unmarshal(ins interface{}) error { - var structSet func(t reflect.Type, v reflect.Value, oriSeg string) error - t := reflect.TypeOf(ins) - v := reflect.ValueOf(ins).Elem() - if v.Kind() != reflect.Struct { - return errors.New("Not a Struct") - } - if t.Kind() != reflect.Ptr || !v.CanSet() { - return errors.New("Cannot Write!") - } - t = t.Elem() - structSet = func(t reflect.Type, v reflect.Value, oriSeg string) error { - for i := 0; i < t.NumField(); i++ { - tp := t.Field(i) - vl := v.Field(i) - if !vl.CanSet() { - continue - } - if vl.Type().Kind() == reflect.Struct { - structSet(vl.Type(), vl, tp.Tag.Get("seg")) - continue - } - seg := tp.Tag.Get("seg") - key := tp.Tag.Get("key") - if key != "" && seg == "" && cfg.HaveSegMent { - seg = "unnamed" - } - if oriSeg != "" { - seg = oriSeg - } - if seg == "" || key == "" { - continue - } - if _, ok := cfg.segmap[seg]; !ok { - continue - } - segs := cfg.Data[cfg.segmap[seg]] - if segs.Get(key) == "" { - continue - } - switch vl.Kind() { - case reflect.String: - vl.SetString(segs.Get(key)) - case reflect.Int, reflect.Int32, reflect.Int64: - vl.SetInt(segs.Int64(key)) - case reflect.Float32, reflect.Float64: - vl.SetFloat(segs.Float64(key)) - case reflect.Bool: - vl.SetBool(segs.Bool(key)) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - vl.SetUint(uint64(segs.Int64(key))) - default: - continue - } - } - return nil - } - return structSet(t, v, "") -} - -// Marshal 输出结果到结构体中 -func (cfg *SysConf) Marshal(ins interface{}) ([]byte, error) { - var structSet func(t reflect.Type, v reflect.Value, oriSeg string) - t := reflect.TypeOf(ins) - v := reflect.ValueOf(ins) - if v.Kind() != reflect.Struct { - return nil, errors.New("Not a Struct") - } - if t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - structSet = func(t reflect.Type, v reflect.Value, oriSeg string) { - for i := 0; i < t.NumField(); i++ { - var seg, key, comment string = "", "", "" - tp := t.Field(i) - vl := v.Field(i) - if vl.Type().Kind() == reflect.Struct { - structSet(vl.Type(), vl, tp.Tag.Get("seg")) - continue - } - seg = tp.Tag.Get("seg") - key = tp.Tag.Get("key") - comment = tp.Tag.Get("comment") - if oriSeg != "" { - seg = oriSeg - } - if seg == "" || key == "" { - continue - } - if _, ok := cfg.segmap[seg]; !ok { - cfg.AddSeg(seg) - } - cfg.Seg(seg).Set(key, fmt.Sprint(vl), comment) - } - - } - structSet(t, v, "") - return cfg.Build(), nil -} - -func (syscfg *SysConf) Seg(name string) *SysSegment { - if _, ok := syscfg.segmap[name]; !ok { - return nil - } - seg := syscfg.Data[syscfg.segmap[name]] - return seg -} - -func (syscfg *SysConf) AddSeg(name string) *SysSegment { - syscfg.lock.Lock() - defer syscfg.lock.Unlock() - if _, ok := syscfg.segmap[name]; !ok { - newseg := new(SysSegment) - newseg.Name = name - newseg.nodeMap = make(map[string]int64) - syscfg.Data = append(syscfg.Data, newseg) - syscfg.segId++ - if syscfg.segmap == nil { - syscfg.segId = 0 - syscfg.segmap = make(map[string]int64) - } - syscfg.segmap[newseg.Name] = syscfg.segId - return newseg - } - seg := syscfg.Data[syscfg.segmap[name]] - return seg -} - -func (syscfg *SysConf) DeleteSeg(name string) error { - syscfg.lock.Lock() - defer syscfg.lock.Unlock() - if _, ok := syscfg.segmap[name]; !ok { - return errors.New("Seg Not Exists") - } - syscfg.Data[syscfg.segmap[name]] = nil - delete(syscfg.segmap, name) - return nil -} - -func (syscfg *SysSegment) GetComment(key string) string { - if v, ok := syscfg.nodeMap[key]; !ok { - return "" - } else { - return syscfg.NodeData[v].Comment - } -} - -func (syscfg *SysSegment) SetComment(key, comment string) error { - syscfg.lock.Lock() - defer syscfg.lock.Unlock() - if v, ok := syscfg.nodeMap[key]; !ok { - return errors.New("Key Not Exists") - } else { - syscfg.NodeData[v].Comment = comment - return nil - } -} - -func (syscfg *SysSegment) Exist(key string) bool { - if _, ok := syscfg.nodeMap[key]; !ok { - return false - } else { - return true - } -} - -func (syscfg *SysSegment) Get(key string) string { - if v, ok := syscfg.nodeMap[key]; !ok { - return "" - } else { - if len(syscfg.NodeData[v].Value) >= 1 { - return syscfg.NodeData[v].Value[0] - } - } - return "" -} - -func (syscfg *SysSegment) GetAll(key string) []string { - if v, ok := syscfg.nodeMap[key]; !ok { - return []string{} - } else { - return syscfg.NodeData[v].Value - } -} - -func (syscfg *SysSegment) Int(key string) int { - val := syscfg.Get(key) - if val == "" { - return 0 - } - res, _ := strconv.Atoi(val) - return res -} - -func (syscfg *SysSegment) Int64(key string) int64 { - val := syscfg.Get(key) - if val == "" { - return 0 - } - res, _ := strconv.ParseInt(val, 10, 64) - return res -} - -func (syscfg *SysSegment) Int32(key string) int32 { - val := syscfg.Get(key) - if val == "" { - return 0 - } - res, _ := strconv.ParseInt(val, 10, 32) - return int32(res) -} - -func (syscfg *SysSegment) Float64(key string) float64 { - val := syscfg.Get(key) - if val == "" { - return 0 - } - res, _ := strconv.ParseFloat(val, 64) - return res -} - -func (syscfg *SysSegment) Float32(key string) float32 { - val := syscfg.Get(key) - if val == "" { - return 0 - } - res, _ := strconv.ParseFloat(val, 32) - return float32(res) -} - -func (syscfg *SysSegment) Bool(key string) bool { - val := syscfg.Get(key) - if val == "" { - return false - } - res, _ := strconv.ParseBool(val) - return res -} - -func (syscfg *SysSegment) SetBool(key string, value bool, comment string) error { - res := strconv.FormatBool(value) - return syscfg.Set(key, res, comment) -} - -func (syscfg *SysSegment) SetFloat64(key string, prec int, value float64, comment string) error { - res := strconv.FormatFloat(value, 'f', prec, 64) - return syscfg.Set(key, res, comment) -} - -func (syscfg *SysSegment) SetFloat32(key string, prec int, value float32, comment string) error { - res := strconv.FormatFloat(float64(value), 'f', prec, 32) - return syscfg.Set(key, res, comment) -} - -func (syscfg *SysSegment) SetUint64(key string, value uint64, comment string) error { - res := strconv.FormatUint(value, 10) - return syscfg.Set(key, res, comment) -} - -func (syscfg *SysSegment) SetInt64(key string, value int64, comment string) error { - res := strconv.FormatInt(value, 10) - return syscfg.Set(key, res, comment) -} - -func (syscfg *SysSegment) SetInt32(key string, value int32, comment string) error { - res := strconv.FormatInt(int64(value), 10) - return syscfg.Set(key, res, comment) -} - -func (syscfg *SysSegment) SetInt(key string, value int, comment string) error { - res := strconv.Itoa(value) - return syscfg.Set(key, res, comment) -} - -func (syscfg *SysSegment) Set(key, value, comment string) error { - syscfg.lock.Lock() - defer syscfg.lock.Unlock() - if v, ok := syscfg.nodeMap[key]; !ok { - node := new(SysNode) - node.Key = key - node.Value = append(node.Value, value) - node.Comment = comment - syscfg.NodeData = append(syscfg.NodeData, node) - syscfg.nodeMap[key] = syscfg.nodeId - syscfg.nodeId++ - return nil - } else { - syscfg.NodeData[v].Value = []string{value} - if comment != "" { - syscfg.NodeData[v].Comment = comment - } - } - return nil -} -func (syscfg *SysSegment) SetAll(key string, value []string, comment string) error { - syscfg.lock.Lock() - defer syscfg.lock.Unlock() - if v, ok := syscfg.nodeMap[key]; !ok { - node := new(SysNode) - node.Key = key - node.Value = value - node.Comment = comment - syscfg.NodeData = append(syscfg.NodeData, node) - syscfg.nodeMap[key] = syscfg.nodeId - syscfg.nodeId++ - return nil - } else { - syscfg.NodeData[v].Value = value - if comment != "" { - syscfg.NodeData[v].Comment = comment - } - } - return nil -} -func (syscfg *SysSegment) AddValue(key, value, comment string) error { - syscfg.lock.Lock() - defer syscfg.lock.Unlock() - if v, ok := syscfg.nodeMap[key]; !ok { - node := new(SysNode) - node.Key = key - node.Value = append(node.Value, value) - node.Comment = comment - syscfg.NodeData = append(syscfg.NodeData, node) - syscfg.nodeMap[key] = syscfg.nodeId - syscfg.nodeId++ - return nil - } else { - syscfg.NodeData[v].Value = append(syscfg.NodeData[v].Value, value) - if comment != "" { - syscfg.NodeData[v].Comment = comment - } - } - return nil -} - -func (syscfg *SysSegment) Delete(key string) error { - syscfg.lock.Lock() - defer syscfg.lock.Unlock() - if v, ok := syscfg.nodeMap[key]; !ok { - return errors.New("Key not exists!") - } else { - if syscfg.NodeData[v].Comment != "" { - cmtSet := false - for j := v - 1; j >= 0; j-- { - if syscfg.NodeData[j] != nil { - syscfg.NodeData[j].Comment += syscfg.NodeData[v].Comment - cmtSet = true - break - } - } - if !cmtSet { - syscfg.Comment += syscfg.NodeData[v].Comment - } - } - syscfg.NodeData[v] = nil - delete(syscfg.nodeMap, key) - } - return nil -} - -func (syscfg *SysSegment) DeleteValue(key string, Value string) error { - syscfg.lock.Lock() - defer syscfg.lock.Unlock() - if v, ok := syscfg.nodeMap[key]; !ok { - return errors.New("Key not exists!") - } else { - data := syscfg.NodeData[v].Value - var vals []string - for _, v := range data { - if v != Value { - vals = append(vals, v) - } - } - syscfg.NodeData[v].Value = vals - } - return nil -} diff --git a/sysconf/sysconf_test.go b/sysconf/sysconf_test.go index 76b589a..d81d3a4 100644 --- a/sysconf/sysconf_test.go +++ b/sysconf/sysconf_test.go @@ -1,67 +1,1046 @@ package sysconf import ( - "fmt" + "bytes" + "errors" + "os" + "path/filepath" + "strconv" "testing" + "time" ) -func Test_SliceIn(t *testing.T) { - slice := []string{"ok", "11", "22"} - fmt.Println(SliceIn(slice, "ok")) - fmt.Println(SliceIn(slice, []rune("22"))) - fmt.Println(SliceIn(slice, "342423r")) - fmt.Println(SliceIn(slice, 444)) -} - -func Test_Parse(t *testing.T) { - data := ` -1.2.3.4 'ok.com' -#2.3.4.5 ppp.eor -2.3.4.5 'pp.com' -5.6.7.8 'ok.com' -` - cfg := new(SysConf) - cfg.SegStart = "[" - cfg.SegEnd = "]" - cfg.ValueFlag = "'" - cfg.EqualFlag = " " - //cfg.CommentCR = true - cfg.CommentFlag = []string{"#"} - cfg.EscapeFlag = "\\" - cfg.HaveSegMent = false - cfg.segmap = make(map[string]int64) - fmt.Println(cfg.Parse([]byte(data))) - cfg.Reverse() - cfg.Data[0].Delete(`pp.com`) - //fmt.Println(cfg.Data[0].comment) - fmt.Println(string(cfg.Build())) -} - -type slicetest struct { - A string `seg:"s" key:"a"` - B string `seg:"a" key:"b"` -} - -type testme struct { - Love slicetest `seg:"love"` - Star slicetest `seg:"star"` -} - -func Test_Marshal(t *testing.T) { - - var info string = ` -[love] -a=abc -b=123 -[star] -a=456 -b=789 -` - var tmp testme +func TestIniParseBuild(t *testing.T) { ini := NewIni() - ini.Parse([]byte(info)) - ini.Unmarshal(&tmp) - fmt.Printf("%+v\n", tmp) - b, _ := ini.Marshal(tmp) - fmt.Println(string(b)) + input := []byte("[app]\r\nname = demo\r\nname = second\r\nflag\r\n\r\n[app]\r\nother=value\r\n") + if err := ini.Parse(input); err != nil { + t.Fatalf("parse failed: %v", err) + } + if got := ini.Get("app", "name"); got != "demo" { + t.Fatalf("unexpected first value: %q", got) + } + if got := ini.GetAll("app", "name"); len(got) != 2 { + t.Fatalf("expected duplicate values, got %v", got) + } + if got := len(ini.Sections("app")); got != 2 { + t.Fatalf("expected duplicate sections, got %d", got) + } + if !ini.Section("app").Exist("flag") { + t.Fatalf("expected no-value key") + } + if out := ini.Build(); !bytes.Contains(out, []byte("name = demo")) || !bytes.Contains(out, []byte("[app]\r\nother=value")) { + t.Fatalf("expected lossless build, got: %q", out) + } + if err := ini.Save("/tmp/sysconf-ini-test.ini"); err != nil { + t.Fatalf("save failed: %v", err) + } +} + +func TestIniSetAllReplacesDuplicates(t *testing.T) { + ini := NewIni() + if err := ini.Parse([]byte("[app]\nname=a\nname=b\n")); err != nil { + t.Fatalf("parse failed: %v", err) + } + sec := ini.Section("app") + if sec == nil { + t.Fatalf("missing section") + } + if err := sec.SetAll("name", []string{"x"}, ""); err != nil { + t.Fatalf("setall failed: %v", err) + } + if got := sec.GetAll("name"); len(got) != 1 || got[0] != "x" { + t.Fatalf("unexpected values after setall: %v", got) + } + if out := string(ini.Build()); bytes.Contains([]byte(out), []byte("name=b")) { + t.Fatalf("duplicate value leaked into build: %q", out) + } +} + +func TestIniMarshalUnmarshal(t *testing.T) { + type nested struct { + Host string `key:"host"` + Port int `key:"port"` + Tags []string `key:"tag"` + Meta map[string]string `key:"meta"` + Skip string `key:"-"` + } + type cfg struct { + App nested `seg:"app"` + } + src := cfg{App: nested{ + Host: "127.0.0.1", + Port: 8080, + Tags: []string{"alpha", "beta"}, + Meta: map[string]string{"b": "second", "a": "first"}, + Skip: "hidden", + }} + ini := NewIni() + out, err := ini.Marshal(src) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + if !bytes.Contains(out, []byte("host=127.0.0.1")) { + t.Fatalf("marshal output missing host: %s", out) + } + if bytes.Contains(out, []byte("tag=[alpha beta]")) || !bytes.Contains(out, []byte("tag=alpha\ntag=beta")) { + t.Fatalf("marshal output should write repeated tag keys: %s", out) + } + if !bytes.Contains(out, []byte("meta=a=first\nmeta=b=second")) { + t.Fatalf("marshal output should write map keys and values: %s", out) + } + if bytes.Contains(out, []byte("skip=")) { + t.Fatalf("marshal output should skip key:\"-\" fields: %s", out) + } + if ini.Section("app") != nil && ini.Section("app").Exist("skip") { + t.Fatalf("marshal output should not create skip key") + } + if err := ini.Parse(out); err != nil { + t.Fatalf("parse marshaled output failed: %v", err) + } + var dst cfg + if err := ini.Unmarshal(&dst); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + if dst.App.Host != "127.0.0.1" || dst.App.Port != 8080 || len(dst.App.Tags) != 2 || dst.App.Tags[0] != "alpha" || dst.App.Tags[1] != "beta" { + t.Fatalf("unexpected result: %+v", dst.App) + } + if dst.App.Meta["a"] != "first" || dst.App.Meta["b"] != "second" { + t.Fatalf("unexpected map result: %+v", dst.App.Meta) + } + if dst.App.Skip != "" { + t.Fatalf("unmarshal should keep skip field empty, got %q", dst.App.Skip) + } + + input := NewIni() + if err := input.Parse([]byte("[app]\nskip=input-only\nhost=127.0.0.1\nport=8080\n")); err != nil { + t.Fatalf("parse skip input failed: %v", err) + } + var skipDst cfg + if err := input.Unmarshal(&skipDst); err != nil { + t.Fatalf("unmarshal skip input failed: %v", err) + } + if skipDst.App.Skip != "" { + t.Fatalf("unmarshal should ignore skip key even when input contains it, got %q", skipDst.App.Skip) + } + if skipDst.App.Host != "127.0.0.1" || skipDst.App.Port != 8080 { + t.Fatalf("unmarshal should still bind normal fields: %+v", skipDst.App) + } + input.Set("app", "skip", "input-only") + if got := input.Get("app", "skip"); got != "input-only" { + t.Fatalf("input setup failed: %q", got) + } + if err := input.Unmarshal(&skipDst); err != nil { + t.Fatalf("unmarshal with skip key failed: %v", err) + } + if skipDst.App.Skip != "" { + t.Fatalf("unmarshal should ignore skip key even after set, got %q", skipDst.App.Skip) + } + if err := ini.Unmarshal(&dst); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } +} + +func TestIniMarshalUsesReceiverProfile(t *testing.T) { + type app struct { + Name string `key:"name"` + } + type cfg struct { + App app `seg:"app"` + } + + ini := NewIniWithProfiles(LinuxConfProfile(":")) + out, err := ini.Marshal(cfg{App: app{Name: "demo"}}) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + if !bytes.Contains(out, []byte("name:demo")) { + t.Fatalf("marshal should use receiver delimiter, got %q", out) + } + if bytes.Contains(out, []byte("name=demo")) { + t.Fatalf("marshal should not fall back to default delimiter: %q", out) + } +} + +func TestIniDuplicateSectionsBindAcrossAllSections(t *testing.T) { + type app struct { + Name string `key:"name"` + Port int `key:"port"` + Tags []string `key:"tag"` + } + type cfg struct { + App app `seg:"app"` + } + + ini := NewIni() + if err := ini.Parse([]byte("[app]\nname=one\ntag=alpha\n[app]\nport=2\ntag=beta\n")); err != nil { + t.Fatalf("parse failed: %v", err) + } + if got := ini.Get("app", "port"); got != "2" { + t.Fatalf("Get should see later duplicate section key, got %q", got) + } + if !ini.Has("app", "port") { + t.Fatal("Has should see later duplicate section key") + } + if got := ini.GetAll("app", "tag"); len(got) != 2 || got[0] != "alpha" || got[1] != "beta" { + t.Fatalf("GetAll should aggregate duplicate sections, got %#v", got) + } + var out cfg + if err := ini.Unmarshal(&out); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + if out.App.Name != "one" || out.App.Port != 2 { + t.Fatalf("duplicate-section bind mismatch: %+v", out.App) + } + if len(out.App.Tags) != 2 || out.App.Tags[0] != "alpha" || out.App.Tags[1] != "beta" { + t.Fatalf("duplicate-section repeated values mismatch: %#v", out.App.Tags) + } +} + +func TestIniMarshalAndUnmarshalNestedPointerSection(t *testing.T) { + type server struct { + Host string `key:"host"` + Port int `key:"port"` + } + type cfg struct { + Server *server `seg:"server"` + } + + ini := NewIni() + out, err := ini.Marshal(cfg{Server: &server{Host: "127.0.0.1", Port: 8080}}) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + if !bytes.Contains(out, []byte("[server]\nhost=127.0.0.1\nport=8080")) { + t.Fatalf("marshal should emit pointer section fields, got %q", out) + } + + parsed := NewIni() + if err := parsed.Parse(out); err != nil { + t.Fatalf("parse marshaled output failed: %v", err) + } + var got cfg + if err := parsed.Unmarshal(&got); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + if got.Server == nil || got.Server.Host != "127.0.0.1" || got.Server.Port != 8080 { + t.Fatalf("pointer section round-trip mismatch: %+v", got.Server) + } +} + +func TestIniMarshalUnmarshalRootSection(t *testing.T) { + type cfg struct { + Root string `seg:"" key:"root"` + } + + ini := NewIni() + out, err := ini.Marshal(cfg{Root: "ok"}) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + if !bytes.Contains(out, []byte("root=ok")) { + t.Fatalf("marshal should include root key, got %q", out) + } + + parsed := NewIni() + if err := parsed.Parse([]byte("root=from-input\n")); err != nil { + t.Fatalf("parse root input failed: %v", err) + } + var got cfg + if err := parsed.Unmarshal(&got); err != nil { + t.Fatalf("unmarshal root input failed: %v", err) + } + if got.Root != "from-input" { + t.Fatalf("root key did not bind: %+v", got) + } +} + +func TestIniUnmarshalSkipsMissingKeys(t *testing.T) { + type app struct { + Host string `key:"host"` + Port int `key:"port"` + Enabled bool `key:"enabled"` + Tags []string `key:"tag"` + Meta map[string]string `key:"meta"` + } + type cfg struct { + App app `seg:"app"` + } + + ini := NewIni() + if err := ini.Parse([]byte("[app]\nhost=127.0.0.1\n")); err != nil { + t.Fatalf("parse failed: %v", err) + } + got := cfg{App: app{ + Port: 8080, + Enabled: true, + Tags: []string{"keep"}, + Meta: map[string]string{"keep": "value"}, + }} + if err := ini.Unmarshal(&got); err != nil { + t.Fatalf("unmarshal should ignore missing keys, got %v", err) + } + if got.App.Host != "127.0.0.1" || got.App.Port != 8080 || !got.App.Enabled { + t.Fatalf("missing scalar keys should not overwrite existing values: %+v", got.App) + } + if len(got.App.Tags) != 1 || got.App.Tags[0] != "keep" || got.App.Meta["keep"] != "value" { + t.Fatalf("missing collection keys should not overwrite existing values: %+v", got.App) + } +} + +func TestMarshalCSVSkipsUnexportedFields(t *testing.T) { + type row struct { + a string + B string + } + out, err := MarshalCSV([]string{"B"}, []row{{a: "hidden", B: "shown"}}) + if err != nil { + t.Fatalf("marshal csv failed: %v", err) + } + if !bytes.Contains(out, []byte("shown")) { + t.Fatalf("expected exported field in csv, got %q", out) + } +} + +func TestMarshalCSVRejectsMismatchedRowLength(t *testing.T) { + if _, err := MarshalCSV([]string{"A", "B"}, [][]string{{"a", "b"}, {"c"}}); err == nil { + t.Fatal("expected header row length mismatch error") + } + if _, err := MarshalCSV(nil, [][]string{{"a", "b"}, {"c"}}); err == nil { + t.Fatal("expected inferred row length mismatch error") + } +} + +func TestIniNoValueInlineCommentAndQuotedComment(t *testing.T) { + ini := NewIni() + input := []byte("[app]\nflag # enabled\nvalue=\"a'b # still value\" # comment\n") + if err := ini.Parse(input); err != nil { + t.Fatalf("parse failed: %v", err) + } + sec := ini.Section("app") + if sec == nil { + t.Fatalf("missing app section") + } + if !sec.Exist("flag") || sec.Comment("flag") != "enabled" { + t.Fatalf("no-value inline comment parsed incorrectly") + } + if got := sec.Get("value"); got != "a'b # still value" { + t.Fatalf("quoted comment parsed incorrectly: %q", got) + } + if got := sec.Comment("value"); got != "comment" { + t.Fatalf("value comment parsed incorrectly: %q", got) + } +} + +func TestIniParsesCommonDelimiterSectionCommentAndContinuation(t *testing.T) { + ini := NewIni() + input := []byte("[app] ; header comment\nname: demo\nurl = http://example.test/a#frag\nmessage = first \\\n second # tail\nquoted = \"a # b\"\n") + if err := ini.Parse(input); err != nil { + t.Fatalf("parse failed: %v", err) + } + sec := ini.Section("app") + if sec == nil { + t.Fatalf("missing app section") + } + if sec.HeaderComment != "header comment" { + t.Fatalf("section header comment parsed incorrectly: %q", sec.HeaderComment) + } + if got := sec.Get("name"); got != "demo" { + t.Fatalf("colon-delimited value parsed incorrectly: %q", got) + } + if entry := sec.Entry("name"); entry == nil || entry.Delimiter != ":" { + t.Fatalf("colon delimiter was not preserved: %#v", entry) + } + if got := sec.Get("url"); got != "http://example.test/a#frag" { + t.Fatalf("hash without leading space should stay in value: %q", got) + } + if got := sec.Get("message"); got != "first second" { + t.Fatalf("continued value parsed incorrectly: %q", got) + } + if got := sec.Comment("message"); got != "tail" { + t.Fatalf("continued line comment parsed incorrectly: %q", got) + } + if got := sec.Get("quoted"); got != "a # b" { + t.Fatalf("quoted value parsed incorrectly: %q", got) + } + if out := ini.Build(); !bytes.Contains(out, []byte("message = first \\\n second # tail\n")) { + t.Fatalf("unchanged continuation was not preserved: %q", out) + } +} + +func TestIniWriteQuotesAmbiguousValues(t *testing.T) { + ini := NewIni() + ini.Set("app", "hash", "value # not comment") + ini.Set("app", "space", " leading") + ini.Set("app", "line", "a\nb") + + out := ini.Build() + for _, want := range [][]byte{ + []byte(`hash="value # not comment"`), + []byte(`space=" leading"`), + []byte(`line="a\nb"`), + } { + if !bytes.Contains(out, want) { + t.Fatalf("quoted output missing %q in %q", want, out) + } + } + + roundTrip := NewIni() + if err := roundTrip.Parse(out); err != nil { + t.Fatalf("round-trip parse failed: %v", err) + } + if got := roundTrip.Get("app", "hash"); got != "value # not comment" { + t.Fatalf("quoted hash value did not round trip: %q", got) + } + if got := roundTrip.Get("app", "space"); got != " leading" { + t.Fatalf("quoted leading space did not round trip: %q", got) + } + if got := roundTrip.Get("app", "line"); got != "a\nb" { + t.Fatalf("quoted newline did not round trip: %q", got) + } +} + +func TestIniStrictParseErrorReportsLocation(t *testing.T) { + ini := NewIni() + ini.Strict = true + err := ini.Parse([]byte("[app]\n=value\n")) + var parseErr *ParseError + if !errors.As(err, &parseErr) { + t.Fatalf("expected ParseError, got %T: %v", err, err) + } + if parseErr.Line != 2 || parseErr.Column != 1 { + t.Fatalf("unexpected parse error location: line=%d column=%d", parseErr.Line, parseErr.Column) + } + if parseErr.Message == "" { + t.Fatalf("parse error should include message") + } +} + +func TestIniSectionRenameRebuildsHeader(t *testing.T) { + ini := NewIni() + if err := ini.Parse([]byte("[old]\nname=value\n")); err != nil { + t.Fatalf("parse failed: %v", err) + } + sec := ini.Section("old") + if sec == nil { + t.Fatalf("missing old section") + } + sec.Name = "new" + if out := ini.Build(); !bytes.Contains(out, []byte("[new]\n")) || bytes.Contains(out, []byte("[old]\n")) { + t.Fatalf("section rename not reflected in build: %q", out) + } + if got := ini.Section("new"); got != sec { + t.Fatalf("renamed section was not indexed under new name") + } + if got := ini.Section("old"); got != nil { + t.Fatalf("renamed section still indexed under old name: %#v", got) + } +} + +func TestParseCSVPreservesBoundaryWhitespace(t *testing.T) { + csvData, err := ParseCSV([]byte(" col ,name\n value ,demo \n"), true) + if err != nil { + t.Fatalf("parse csv failed: %v", err) + } + if got := csvData.Header()[0]; got != " col " { + t.Fatalf("header whitespace was trimmed: %q", got) + } + row := csvData.Row(0) + if row == nil || row.Col(0).value != " value " || row.Col(1).value != "demo " { + t.Fatalf("row whitespace was not preserved: %#v", row) + } +} + +type configFrameworkApp struct { + Name string `key:"name" required:"true"` + Port int `key:"port" default:"8080"` + Enabled bool `key:"enabled" default:"true"` + Timeout time.Duration `key:"timeout" default:"2s"` + Retries []int `key:"retry" default:"1,2" split:","` + Tags []string `key:"tag"` + Limits map[string]int `key:"limit" default:"read=10,write=20" split:","` + Token string `key:"token" env:"APP_SECRET"` + SkipEnv string `key:"skip_env" env:"-" default:"file-only"` +} + +type configFrameworkServer struct { + Host string `key:"host" default:"127.0.0.1"` +} + +type configFrameworkExample struct { + App configFrameworkApp `seg:"app"` + Server configFrameworkServer `seg:"server"` +} + +func (c *configFrameworkExample) Validate() error { + if c.App.Port <= 0 { + return errors.New("port must be positive") + } + return nil +} + +func TestConfigFrameworkLoadsOverridesDefaultsEnvAndValidate(t *testing.T) { + dir := t.TempDir() + base := filepath.Join(dir, "base.ini") + override := filepath.Join(dir, "override.ini") + if err := os.WriteFile(base, []byte("[app]\nname=demo\nport=1000\ntag=base\nretry=3\nlimit=read=11\nskip_env=from-file\n[server]\nhost=0.0.0.0\n"), 0o644); err != nil { + t.Fatalf("write base config failed: %v", err) + } + if err := os.WriteFile(override, []byte("[app]\nport=2000\ntag=override\ntag=extra\nlimit=write=22\nenabled\n"), 0o644); err != nil { + t.Fatalf("write override config failed: %v", err) + } + + env := map[string]string{ + "APP_APP_PORT": "3000", + "APP_APP_TIMEOUT": "5s", + "APP_APP_RETRY": "7,8,9", + "APP_SECRET": "token-from-env", + "APP_APP_SKIP_ENV": "ignored", + } + var dst configFrameworkExample + cfg, err := LoadConfig(&dst, []string{base, override}, + WithEnvPrefix("APP"), + WithEnvLookup(func(key string) (string, bool) { + value, ok := env[key] + return value, ok + }), + ) + if err != nil { + t.Fatalf("load config failed: %v", err) + } + if dst.App.Name != "demo" || dst.App.Port != 3000 || !dst.App.Enabled { + t.Fatalf("basic bind mismatch: %+v", dst.App) + } + if dst.App.Timeout != 5*time.Second { + t.Fatalf("duration env override mismatch: %s", dst.App.Timeout) + } + if got := dst.App.Retries; len(got) != 3 || got[0] != 7 || got[1] != 8 || got[2] != 9 { + t.Fatalf("slice env override mismatch: %#v", got) + } + if got := dst.App.Tags; len(got) != 2 || got[0] != "override" || got[1] != "extra" { + t.Fatalf("repeated key override mismatch: %#v", got) + } + if dst.App.Limits["write"] != 22 || dst.App.Limits["read"] != 0 { + t.Fatalf("map override mismatch: %#v", dst.App.Limits) + } + if dst.App.Token != "token-from-env" || dst.App.SkipEnv != "from-file" { + t.Fatalf("env handling mismatch: token=%q skip=%q", dst.App.Token, dst.App.SkipEnv) + } + if dst.Server.Host != "0.0.0.0" { + t.Fatalf("nested section bind mismatch: %q", dst.Server.Host) + } + if cfg.Get("app", "port") != "3000" { + t.Fatalf("config access did not see env override: %q", cfg.Get("app", "port")) + } + if values := cfg.GetAll("app", "tag"); len(values) != 2 || values[0] != "override" || values[1] != "extra" { + t.Fatalf("config repeated values mismatch: %#v", values) + } + + cfg.Set("app", "name", "saved") + outPath := filepath.Join(dir, "saved.ini") + if err := cfg.Save(outPath); err != nil { + t.Fatalf("save config failed: %v", err) + } + out, err := os.ReadFile(outPath) + if err != nil { + t.Fatalf("read saved config failed: %v", err) + } + if !bytes.Contains(out, []byte("name=saved")) || !bytes.Contains(out, []byte("retry=7")) || !bytes.Contains(out, []byte("retry=8")) || !bytes.Contains(out, []byte("retry=9")) { + t.Fatalf("saved config missing expected values: %q", out) + } +} + +func TestConfigFrameworkReportsRequiredAndValidateErrors(t *testing.T) { + var missing configFrameworkExample + _, err := LoadConfig(&missing, nil) + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("expected required ConfigError, got %T: %v", err, err) + } + if cfgErr.Section != "app" || cfgErr.Key != "name" { + t.Fatalf("required error points at wrong field: %#v", cfgErr) + } + + var invalid configFrameworkExample + _, err = LoadConfig(&invalid, nil, WithEnvLookup(func(key string) (string, bool) { + switch key { + case "APP_NAME", "APP_APP_NAME": + return "demo", true + case "APP_APP_PORT": + return "-1", true + default: + return "", false + } + }), WithEnvPrefix("APP")) + if err == nil || err.Error() != "port must be positive" { + t.Fatalf("expected validate error, got %v", err) + } +} + +func TestConfigFrameworkEnvIsExplicitOptIn(t *testing.T) { + t.Setenv("APP_NAME", "from-env") + t.Setenv("APP_PORT", "9090") + + type cfg struct { + Name string `key:"name" required:"true"` + Port int `key:"port" default:"8080"` + } + + var disabled cfg + if _, err := LoadConfig(&disabled, nil); err == nil { + t.Fatalf("expected missing required value when env is not enabled") + } + + var enabled cfg + if _, err := LoadConfig(&enabled, nil, WithEnvPrefix("APP")); err != nil { + t.Fatalf("load with explicit env failed: %v", err) + } + if enabled.Name != "from-env" || enabled.Port != 9090 { + t.Fatalf("env override mismatch: %+v", enabled) + } +} + +func TestConfigFrameworkRequiredNoValueDependsOnFieldType(t *testing.T) { + type cfg struct { + App struct { + Name string `key:"name" required:"true"` + Enabled bool `key:"enabled" required:"true"` + } `seg:"app"` + } + + var missingName cfg + loader := NewConfig() + if err := loader.LoadBytes([]byte("[app]\nname\nenabled\n")); err != nil { + t.Fatalf("load bytes failed: %v", err) + } + err := loader.Bind(&missingName) + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("expected no-value string required error, got %T: %v", err, err) + } + if cfgErr.Key != "name" || cfgErr.Reason != "required value is empty" { + t.Fatalf("unexpected required error: %#v", cfgErr) + } + + var ok cfg + loader = NewConfig() + if err := loader.LoadBytes([]byte("[app]\nname=demo\nenabled\n")); err != nil { + t.Fatalf("load bytes failed: %v", err) + } + if err := loader.Bind(&ok); err != nil { + t.Fatalf("bind with bool no-value failed: %v", err) + } + if ok.App.Name != "demo" || !ok.App.Enabled { + t.Fatalf("unexpected bind result: %+v", ok.App) + } +} + +func TestConfigFrameworkBindsDuplicateSectionsInOrder(t *testing.T) { + type cfg struct { + App struct { + Name string `key:"name"` + Port int `key:"port"` + Tags []string `key:"tag"` + } `seg:"app"` + } + + loader := NewConfig() + if err := loader.LoadBytes([]byte("[app]\nname=demo\ntag=base\n[app]\nport=2000\ntag=override\n")); err != nil { + t.Fatalf("load bytes failed: %v", err) + } + + var got cfg + if err := loader.Bind(&got); err != nil { + t.Fatalf("bind failed: %v", err) + } + if got.App.Name != "demo" || got.App.Port != 2000 { + t.Fatalf("duplicate section bind mismatch: %+v", got.App) + } + if len(got.App.Tags) != 2 || got.App.Tags[0] != "base" || got.App.Tags[1] != "override" { + t.Fatalf("duplicate section repeated values mismatch: %#v", got.App.Tags) + } +} + +func TestConfigFrameworkSplitTagIsExplicit(t *testing.T) { + type cfg struct { + App struct { + Implicit []string `key:"implicit" default:"a,b"` + CSV []string `key:"csv" default:"x,y" split:","` + } `seg:"app"` + } + + var got cfg + if _, err := LoadConfig(&got, nil); err != nil { + t.Fatalf("load config failed: %v", err) + } + if len(got.App.Implicit) != 1 || got.App.Implicit[0] != "a,b" { + t.Fatalf("implicit split should stay scalar-like: %#v", got.App.Implicit) + } + if len(got.App.CSV) != 2 || got.App.CSV[0] != "x" || got.App.CSV[1] != "y" { + t.Fatalf("explicit split did not apply: %#v", got.App.CSV) + } +} + +func TestConfigFrameworkSplitTagAppliesToFileBinding(t *testing.T) { + type cfg struct { + App struct { + Retries []int `key:"retry" split:"|"` + Tags []string `key:"tag" split:","` + Limits map[string]int `key:"limit" split:";"` + } `seg:"app"` + } + + loader := NewConfig() + if err := loader.LoadBytes([]byte("[app]\nretry=1|2|3\ntag=alpha,beta\ntag=gamma\nlimit=read=10; write=20\n")); err != nil { + t.Fatalf("load bytes failed: %v", err) + } + + var got cfg + if err := loader.Bind(&got); err != nil { + t.Fatalf("bind failed: %v", err) + } + if len(got.App.Retries) != 3 || got.App.Retries[0] != 1 || got.App.Retries[1] != 2 || got.App.Retries[2] != 3 { + t.Fatalf("retry split mismatch: %#v", got.App.Retries) + } + if len(got.App.Tags) != 3 || got.App.Tags[0] != "alpha" || got.App.Tags[1] != "beta" || got.App.Tags[2] != "gamma" { + t.Fatalf("tag split mismatch: %#v", got.App.Tags) + } + if got.App.Limits["read"] != 10 || got.App.Limits["write"] != 20 { + t.Fatalf("limit split mismatch: %#v", got.App.Limits) + } +} + +func TestConfigFrameworkSourcesAndAtomicSave(t *testing.T) { + dir := t.TempDir() + required := filepath.Join(dir, "app.ini") + missing := filepath.Join(dir, "missing.ini") + out := filepath.Join(dir, "saved.ini") + + if err := os.WriteFile(required, []byte("[app]\nname=demo\n"), 0o644); err != nil { + t.Fatalf("write required config failed: %v", err) + } + if err := os.WriteFile(out, []byte("stale\n"), 0o600); err != nil { + t.Fatalf("write existing output failed: %v", err) + } + + type cfg struct { + App struct { + Name string `key:"name" required:"true"` + } `seg:"app"` + } + + var got cfg + loaded, err := LoadConfigSources(&got, []ConfigSource{ + OptionalFile(missing), + RequiredFile(required), + }) + if err != nil { + t.Fatalf("load sources failed: %v", err) + } + if got.App.Name != "demo" { + t.Fatalf("unexpected loaded config: %+v", got) + } + + loaded.Set("app", "name", "saved") + if err := loaded.SaveAtomic(out); err != nil { + t.Fatalf("save atomic failed: %v", err) + } + data, err := os.ReadFile(out) + if err != nil { + t.Fatalf("read saved file failed: %v", err) + } + if !bytes.Contains(data, []byte("name=saved")) { + t.Fatalf("saved file missing updated value: %q", data) + } + info, err := os.Stat(out) + if err != nil { + t.Fatalf("stat saved file failed: %v", err) + } + if info.Mode().Perm() != 0o600 { + t.Fatalf("save atomic should preserve file mode, got %o", info.Mode().Perm()) + } + + _, err = LoadConfigSources(&cfg{}, []ConfigSource{RequiredFile(missing)}) + var sourceErr *ConfigSourceError + if !errors.As(err, &sourceErr) { + t.Fatalf("expected ConfigSourceError, got %T: %v", err, err) + } + if sourceErr.Path != missing || sourceErr.Optional { + t.Fatalf("unexpected source error metadata: %#v", sourceErr) + } + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected wrapped not-exist error, got %v", err) + } +} + +func TestConfigFrameworkMemorySourceAndTypedGetters(t *testing.T) { + cfg := NewConfig() + if err := cfg.LoadSources(StringSource("inline", "[app]\nname=demo\nport=8080\nenabled\nratio=1.5\ntimeout=3s\n")); err != nil { + t.Fatalf("load string source failed: %v", err) + } + + name, err := cfg.GetStringE("app", "name") + if err != nil || name != "demo" { + t.Fatalf("get string mismatch: name=%q err=%v", name, err) + } + port, err := cfg.GetIntE("app", "port") + if err != nil || port != 8080 { + t.Fatalf("get int mismatch: port=%d err=%v", port, err) + } + enabled, err := cfg.GetBoolE("app", "enabled") + if err != nil || !enabled { + t.Fatalf("get bool no-value mismatch: enabled=%v err=%v", enabled, err) + } + ratio, err := cfg.GetFloat64E("app", "ratio") + if err != nil || ratio != 1.5 { + t.Fatalf("get float mismatch: ratio=%v err=%v", ratio, err) + } + timeout, err := cfg.GetDurationE("app", "timeout") + if err != nil || timeout != 3*time.Second { + t.Fatalf("get duration mismatch: timeout=%v err=%v", timeout, err) + } + + if _, err := cfg.GetIntE("app", "missing"); !errors.Is(err, ErrKeyNotFound) { + t.Fatalf("missing typed getter should wrap ErrKeyNotFound, got %v", err) + } + cfg.Set("app", "bad", "not-int") + err = nil + if _, err = cfg.GetIntE("app", "bad"); !errors.Is(err, strconv.ErrSyntax) { + t.Fatalf("invalid typed getter should wrap strconv.ErrSyntax, got %v", err) + } +} + +func TestConfigFrameworkSetStructWritesConfig(t *testing.T) { + type cfg struct { + App struct { + Name string `key:"name"` + Port int `key:"port"` + Enabled bool `key:"enabled"` + Timeout time.Duration `key:"timeout"` + Tags []string `key:"tag"` + Limits map[string]uint64 `key:"limit"` + } `seg:"app"` + } + + src := cfg{} + src.App.Name = "demo" + src.App.Port = 9090 + src.App.Enabled = true + src.App.Timeout = 5 * time.Second + src.App.Tags = []string{"alpha", "beta"} + src.App.Limits = map[string]uint64{"write": 20, "read": 10} + + loader := NewConfig() + if err := loader.SetStruct(src); err != nil { + t.Fatalf("set struct failed: %v", err) + } + if got := loader.Get("app", "name"); got != "demo" { + t.Fatalf("name was not written: %q", got) + } + if got := loader.Get("app", "port"); got != "9090" { + t.Fatalf("port was not written: %q", got) + } + if got := loader.Get("app", "timeout"); got != "5s" { + t.Fatalf("duration was not written: %q", got) + } + if got := loader.GetAll("app", "tag"); len(got) != 2 || got[0] != "alpha" || got[1] != "beta" { + t.Fatalf("slice was not written as repeated keys: %#v", got) + } + if got := loader.GetAll("app", "limit"); len(got) != 2 || got[0] != "read=10" || got[1] != "write=20" { + t.Fatalf("map was not written as sorted repeated keys: %#v", got) + } + + var roundTrip cfg + if err := loader.Bind(&roundTrip); err != nil { + t.Fatalf("round-trip bind failed: %v", err) + } + if roundTrip.App.Name != src.App.Name || roundTrip.App.Port != src.App.Port || roundTrip.App.Timeout != src.App.Timeout { + t.Fatalf("round-trip scalar mismatch: %+v", roundTrip.App) + } + if len(roundTrip.App.Tags) != 2 || roundTrip.App.Tags[0] != "alpha" || roundTrip.App.Tags[1] != "beta" { + t.Fatalf("round-trip slice mismatch: %#v", roundTrip.App.Tags) + } + if roundTrip.App.Limits["read"] != 10 || roundTrip.App.Limits["write"] != 20 { + t.Fatalf("round-trip map mismatch: %#v", roundTrip.App.Limits) + } +} + +func TestConfigFrameworkBindErrorUnwrapsOriginalError(t *testing.T) { + type cfg struct { + App struct { + Port int `key:"port"` + } `seg:"app"` + } + + loader := NewConfig() + if err := loader.LoadSources(BytesSource("bad", []byte("[app]\nport=bad\n"))); err != nil { + t.Fatalf("load bytes source failed: %v", err) + } + var got cfg + err := loader.Bind(&got) + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("expected ConfigError, got %T: %v", err, err) + } + if cfgErr.Section != "app" || cfgErr.Key != "port" || cfgErr.Field != "Port" { + t.Fatalf("unexpected config error metadata: %#v", cfgErr) + } + if !errors.Is(err, strconv.ErrSyntax) { + t.Fatalf("bind error should unwrap strconv.ErrSyntax, got %v", err) + } +} + +func TestConfigFrameworkDescribeAndSampleConfig(t *testing.T) { + type server struct { + Host string `key:"host" default:"127.0.0.1"` + Ports []int `key:"port" default:"8080,8081" split:","` + } + type cfg struct { + App struct { + Name string `key:"name" required:"true"` + Token string `key:"token" env:"APP_TOKEN"` + Tags []string `key:"tag"` + Modes []string `key:"mode" required:"true"` + Limits map[string]int `key:"limit" default:"read=10,write=20" split:","` + Labels map[string]string `key:"label" required:"true"` + Skip string `key:"-"` + } `seg:"app"` + Server *server `seg:"server"` + } + + var src cfg + fields, err := DescribeConfig(&src) + if err != nil { + t.Fatalf("describe config failed: %v", err) + } + if src.Server != nil { + t.Fatalf("describe config should not allocate nil nested pointers") + } + if len(fields) != 8 { + t.Fatalf("unexpected field count: %#v", fields) + } + byPath := make(map[string]ConfigFieldInfo) + for _, field := range fields { + byPath[field.Field] = field + } + name := byPath["App.Name"] + if name.Section != "app" || name.Key != "name" || name.Default != "" || !name.Required || name.Type != "string" { + t.Fatalf("unexpected name field metadata: %#v", name) + } + ports := byPath["Server.Ports"] + if ports.Section != "server" || ports.Key != "port" || ports.Default != "8080,8081" || ports.Split != "," || ports.Type != "[]int" { + t.Fatalf("unexpected ports field metadata: %#v", ports) + } + if token := byPath["App.Token"]; token.Env != "APP_TOKEN" { + t.Fatalf("unexpected token env metadata: %#v", token) + } + if _, ok := byPath["App.Skip"]; ok { + t.Fatalf("key:\"-\" field should be skipped: %#v", byPath["App.Skip"]) + } + + sample, err := SampleConfig(&src) + if err != nil { + t.Fatalf("sample config failed: %v", err) + } + if src.Server != nil { + t.Fatalf("sample config should not allocate nil nested pointers") + } + for _, want := range [][]byte{ + []byte("[app]\n"), + []byte("name=value\n"), + []byte("token=\n"), + []byte("tag=\n"), + []byte("mode=value\n"), + []byte("limit=read=10\n"), + []byte("limit=write=20\n"), + []byte("label=key=value\n"), + []byte("[server]\n"), + []byte("host=127.0.0.1\n"), + []byte("port=8080\n"), + []byte("port=8081\n"), + } { + if !bytes.Contains(sample, want) { + t.Fatalf("sample config missing %q in %q", want, sample) + } + } + + var got cfg + loader := NewConfig() + if err := loader.LoadBytes(sample); err != nil { + t.Fatalf("load sample failed: %v", err) + } + if err := loader.Bind(&got); err != nil { + t.Fatalf("bind sample failed: %v", err) + } + if got.App.Name != "value" || got.App.Limits["read"] != 10 || got.App.Limits["write"] != 20 { + t.Fatalf("sample app values did not bind: %+v", got.App) + } + if len(got.App.Modes) != 1 || got.App.Modes[0] != "value" || got.App.Labels["key"] != "value" { + t.Fatalf("required placeholder values did not bind: %+v", got.App) + } + if got.Server == nil || got.Server.Host != "127.0.0.1" || len(got.Server.Ports) != 2 || got.Server.Ports[0] != 8080 || got.Server.Ports[1] != 8081 { + t.Fatalf("sample server values did not bind: %+v", got.Server) + } +} + +func TestConfigFrameworkFlattenSectionNamesAndKeys(t *testing.T) { + cfg := NewConfig() + cfg.Set("", "root", "top") + cfg.Set("app", "name", "demo") + cfg.SetAll("app", "tag", []string{"alpha", "beta"}) + cfg.Ini().AddValue("app", "tag", "gamma") + cfg.Set("server", "host", "127.0.0.1") + if err := cfg.LoadBytes([]byte("[app]\nflag\n")); err != nil { + t.Fatalf("load no-value config failed: %v", err) + } + + if got := cfg.SectionNames(); len(got) != 3 || got[0] != "" || got[1] != "app" || got[2] != "server" { + t.Fatalf("unexpected section names: %#v", got) + } + if got := cfg.Keys(""); len(got) != 1 || got[0] != "root" { + t.Fatalf("unexpected root keys: %#v", got) + } + if got := cfg.Keys("app"); len(got) != 3 || got[0] != "flag" || got[1] != "name" || got[2] != "tag" { + t.Fatalf("unexpected app keys: %#v", got) + } + + flat := cfg.Flatten() + if got := flat["root"]; len(got) != 1 || got[0] != "top" { + t.Fatalf("unexpected root flatten values: %#v", got) + } + if got := flat["app.name"]; len(got) != 1 || got[0] != "demo" { + t.Fatalf("unexpected app.name flatten values: %#v", got) + } + if got := flat["app.tag"]; len(got) != 3 || got[0] != "alpha" || got[1] != "beta" || got[2] != "gamma" { + t.Fatalf("unexpected app.tag flatten values: %#v", got) + } + if got := flat["app.flag"]; len(got) != 1 || got[0] != "" { + t.Fatalf("unexpected app.flag flatten values: %#v", got) + } +} + +func TestConfigFrameworkFlattenEntriesPreservesStructuredPath(t *testing.T) { + cfg := NewConfig() + cfg.Set("db.primary", "host", "127.0.0.1") + cfg.Set("db", "primary.host", "localhost") + + flat := cfg.Flatten() + if got := flat["db.primary.host"]; len(got) != 2 || got[0] != "127.0.0.1" || got[1] != "localhost" { + t.Fatalf("flatten should keep legacy ambiguous path values: %#v", got) + } + + entries := cfg.FlattenEntries() + if len(entries) != 2 { + t.Fatalf("unexpected flatten entry count: %#v", entries) + } + if entries[0].Section != "db.primary" || entries[0].Key != "host" || entries[0].Path != "db.primary.host" || len(entries[0].Values) != 1 || entries[0].Values[0] != "127.0.0.1" { + t.Fatalf("unexpected first flatten entry: %#v", entries[0]) + } + if entries[1].Section != "db" || entries[1].Key != "primary.host" || entries[1].Path != "db.primary.host" || len(entries[1].Values) != 1 || entries[1].Values[0] != "localhost" { + t.Fatalf("unexpected second flatten entry: %#v", entries[1]) + } + + entries[0].Values[0] = "mutated" + if got := cfg.Get("db.primary", "host"); got != "127.0.0.1" { + t.Fatalf("flatten entries should not expose mutable config values: %q", got) + } } diff --git a/sysconf/typed.go b/sysconf/typed.go index fbc7330..07562e7 100644 --- a/sysconf/typed.go +++ b/sysconf/typed.go @@ -1 +1,28 @@ package sysconf + +import "strconv" + +func (s *Section) Uint64(key string) uint64 { + v, _ := strconv.ParseUint(s.Get(key), 10, 64) + return v +} + +func (s *Section) MustInt(key string) int { + return s.Int(key) +} + +func (s *Section) MustInt64(key string) int64 { + return s.Int64(key) +} + +func (s *Section) MustUint64(key string) uint64 { + return s.Uint64(key) +} + +func (s *Section) MustBool(key string) bool { + return s.Bool(key) +} + +func (s *Section) MustFloat64(key string) float64 { + return s.Float64(key) +} diff --git a/typed.go b/typed.go index 942fb2e..6e0479d 100644 --- a/typed.go +++ b/typed.go @@ -21,12 +21,14 @@ const ( TCP_TIME_WAIT TCP_CLOSE TCP_CLOSE_WAIT - TCP_LAST_ACL + TCP_LAST_ACK TCP_LISTEN TCP_CLOSING ) -var TCP_STATE = []string{"TCP_UNKNOWN", "TCP_ESTABLISHED", "TCP_SYN_SENT", "TCP_SYN_RECV", "TCP_FIN_WAIT1", "TCP_FIN_WAIT2", "TCP_TIME_WAIT", "TCP_CLOSE", "TCP_CLOSE_WAIT", "TCP_LAST_ACL", "TCP_LISTEN", "TCP_CLOSING"} +const TCP_LAST_ACL = TCP_LAST_ACK + +var TCP_STATE = []string{"TCP_UNKNOWN", "TCP_ESTABLISHED", "TCP_SYN_SENT", "TCP_SYN_RECV", "TCP_FIN_WAIT1", "TCP_FIN_WAIT2", "TCP_TIME_WAIT", "TCP_CLOSE", "TCP_CLOSE_WAIT", "TCP_LAST_ACK", "TCP_LISTEN", "TCP_CLOSING"} type NetAdapter struct { Name string diff --git a/typed_test.go b/typed_test.go new file mode 100644 index 0000000..752896c --- /dev/null +++ b/typed_test.go @@ -0,0 +1,12 @@ +package staros + +import "testing" + +func TestTCPStateLastACKSpellingAndCompatibilityAlias(t *testing.T) { + if TCP_STATE[TCP_LAST_ACK] != "TCP_LAST_ACK" { + t.Fatalf("unexpected LAST_ACK state string: %s", TCP_STATE[TCP_LAST_ACK]) + } + if TCP_LAST_ACL != TCP_LAST_ACK { + t.Fatalf("TCP_LAST_ACL should remain a compatibility alias") + } +}