diff --git a/agent-wdd/cmd/Download.go b/agent-wdd/cmd/Download.go index 10e9216..1dc7acb 100644 --- a/agent-wdd/cmd/Download.go +++ b/agent-wdd/cmd/Download.go @@ -3,8 +3,15 @@ package cmd import ( "agent-wdd/log" "agent-wdd/utils" + "context" + "fmt" + "net" + "net/http" + "net/url" + "time" "github.com/spf13/cobra" + "golang.org/x/net/proxy" ) // 示例:添加download子命令 @@ -14,13 +21,31 @@ func addDownloadSubcommands(cmd *cobra.Command) { Short: "使用代理下载 支持socks5代理 http代理", Args: cobra.ExactArgs(3), Run: func(cmd *cobra.Command, args []string) { - // 判定参数是否正确 + proxyURL := args[0] + fileURL := args[1] + destPath := args[2] + + log.Info("Downloading using proxy: %s -> from %s to %s\n", proxyURL, fileURL, destPath) + + // 创建带代理的HTTP客户端 + client, err := createProxyClient(proxyURL) + if err != nil { + log.Error("创建代理客户端失败: %v", err) + } + + // 执行下载 + downloadOk, resultLog := utils.DownloadFileWithClient(client, fileURL, destPath) + if !downloadOk { + log.Error("下载失败: %v", resultLog) + } else { + log.Info("文件下载完成") + } - log.Info("Downloading using proxy: %s -> from %s to %s\n", args[0], args[1], args[2]) }, } cmd.Run = func(cmd *cobra.Command, args []string) { + if len(args) == 0 { log.Error("请输入下载地址") return @@ -42,4 +67,55 @@ func addDownloadSubcommands(cmd *cobra.Command) { cmd.AddCommand(proxyCmd) } -// 根据需求补充其他子命令的添加函数... +// 创建带代理的HTTP客户端 +func createProxyClient(proxyURL string) (*http.Client, error) { + parsedURL, err := url.Parse(proxyURL) + if err != nil { + return nil, fmt.Errorf("无效的代理URL: %w", err) + } + + switch parsedURL.Scheme { + case "socks5": + return createSocks5Client(parsedURL) + case "http", "https": + return createHTTPClient(parsedURL), nil + default: + return nil, fmt.Errorf("不支持的代理协议: %s", parsedURL.Scheme) + } +} + +// 创建SOCKS5代理客户端 +func createSocks5Client(proxyURL *url.URL) (*http.Client, error) { + var auth *proxy.Auth + if proxyURL.User != nil { + password, _ := proxyURL.User.Password() + auth = &proxy.Auth{ + User: proxyURL.User.Username(), + Password: password, + } + } + + dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) + if err != nil { + return nil, fmt.Errorf("创建SOCKS5拨号器失败: %w", err) + } + + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + }, + }, + Timeout: 10 * time.Second, + }, nil +} + +// 创建HTTP代理客户端 +func createHTTPClient(proxyURL *url.URL) *http.Client { + return &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + Timeout: 30 * time.Minute, + } +} diff --git a/agent-wdd/cmd/root.go b/agent-wdd/cmd/root.go index 947f316..5b6d48b 100644 --- a/agent-wdd/cmd/root.go +++ b/agent-wdd/cmd/root.go @@ -112,6 +112,7 @@ func Execute() { Use: "download", Short: "文件下载管理", } + addDownloadSubcommands(downloadCmd) helpCmd := &cobra.Command{ diff --git a/agent-wdd/go.mod b/agent-wdd/go.mod index 8a38581..ea229d3 100644 --- a/agent-wdd/go.mod +++ b/agent-wdd/go.mod @@ -25,8 +25,9 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/net v0.35.0 + golang.org/x/sys v0.30.0 // indirect + golang.org/x/text v0.22.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect - + ) diff --git a/agent-wdd/go.sum b/agent-wdd/go.sum index aba9163..22b16a1 100644 --- a/agent-wdd/go.sum +++ b/agent-wdd/go.sum @@ -63,10 +63,16 @@ go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/agent-wdd/utils/DownloadUtils.go b/agent-wdd/utils/DownloadUtils.go index c8eea9b..599b586 100644 --- a/agent-wdd/utils/DownloadUtils.go +++ b/agent-wdd/utils/DownloadUtils.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "os" + "strings" "time" ) @@ -14,35 +15,94 @@ func DownloadFile(url string, path string) (bool, string) { Timeout: 5 * time.Second, } - // 发送GET请求 + return DownloadFileWithClient(client, url, path) +} + +func DownloadFileWithClient(client *http.Client, url string, path string) (bool, string) { + return downloadWithProgress(client, url, path) +} + +// 带进度显示的下载函数 +func downloadWithProgress(client *http.Client, url, dest string) (bool, string) { + // 创建目标文件 + file, err := os.Create(dest) + if err != nil { + return false, fmt.Sprintf("创建文件失败: %w", err) + } + defer file.Close() + + // 发起请求 resp, err := client.Get(url) if err != nil { - return false, fmt.Sprintf("下载文件失败: %v", err) + return false, fmt.Sprintf("HTTP请求失败: %w", err) } defer resp.Body.Close() - // 检查响应状态码 if resp.StatusCode != http.StatusOK { - return false, fmt.Sprintf("下载文件失败,HTTP状态码: %d", resp.StatusCode) + return false, fmt.Sprintf("服务器返回错误状态码: %s", resp.Status) } - // 创建目标文件 - out, err := os.Create(path) - if err != nil { - return false, fmt.Sprintf("创建文件失败: %v", err) - } - defer out.Close() + // 获取文件大小 + size := resp.ContentLength + var downloaded int64 - // 将响应内容写入文件 - _, err = io.Copy(out, resp.Body) - if err != nil { - return false, fmt.Sprintf("写入文件失败: %v", err) + // 创建带进度跟踪的Reader + progressReader := &progressReader{ + Reader: resp.Body, + Reporter: func(r int64) { + downloaded += r + printProgress(downloaded, size) + }, } - // 检查文件是否存在 - if !FileExistAndNotNull(path) { - return false, fmt.Sprintf("文件下载失败: 文件为空 => %s", path) + // 执行拷贝 + if _, err := io.Copy(file, progressReader); err != nil { + return false, fmt.Sprintf("文件拷贝失败: %w", err) } - return true, fmt.Sprintf("文件下载成功: %s", path) + fmt.Print("\n") // 保持最后进度显示的完整性 + return true, fmt.Sprintf("文件下载成功: %s", dest) +} + +// 进度跟踪Reader +type progressReader struct { + io.Reader + Reporter func(r int64) +} + +func (pr *progressReader) Read(p []byte) (int, error) { + n, err := pr.Reader.Read(p) + if n > 0 { + pr.Reporter(int64(n)) + } + return n, err +} + +// 打印进度信息 +func printProgress(downloaded, total int64) { + const barLength = 40 + percent := float64(downloaded) / float64(total) * 100 + + // 生成进度条 + filled := int(barLength * downloaded / total) + bar := fmt.Sprintf("[%s%s]", + strings.Repeat("=", filled), + strings.Repeat(" ", barLength-filled)) + + // 格式化为人类可读大小 + humanSize := func(bytes int64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) + } + + fmt.Printf("\r%-45s %6.2f%% %s/%s", bar, percent, + humanSize(downloaded), humanSize(total)) }