From 84946e92756af9e465d39a8657de263c82a2683b Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 8 Jun 2023 21:33:18 +0200 Subject: [PATCH] feat: display download progress when installing models (#543) --- api/gallery.go | 54 +++++++++++++++++++++++++------ pkg/gallery/models.go | 66 +++++++++++++++++++++++++++++--------- pkg/gallery/models_test.go | 8 ++--- 3 files changed, 99 insertions(+), 29 deletions(-) diff --git a/api/gallery.go b/api/gallery.go index b5b74b0..a9a8722 100644 --- a/api/gallery.go +++ b/api/gallery.go @@ -10,10 +10,12 @@ import ( "os" "strings" "sync" + "time" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/gofiber/fiber/v2" "github.com/google/uuid" + "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" ) @@ -23,9 +25,12 @@ type galleryOp struct { } type galleryOpStatus struct { - Error error `json:"error"` - Processed bool `json:"processed"` - Message string `json:"message"` + Error error `json:"error"` + Processed bool `json:"processed"` + Message string `json:"message"` + Progress float64 `json:"progress"` + TotalFileSize string `json:"file_size"` + DownloadedFileSize string `json:"downloaded_size"` } type galleryApplier struct { @@ -43,7 +48,7 @@ func newGalleryApplier(modelPath string) *galleryApplier { } } -func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger) error { +func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { url, err := req.DecodeURL() if err != nil { return err @@ -71,7 +76,7 @@ func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerg config.Files = append(config.Files, req.AdditionalFiles...) - if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides); err != nil { + if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides, downloadStatus); err != nil { return err } @@ -99,23 +104,51 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { case <-c.Done(): return case op := <-g.C: - g.updatestatus(op.id, &galleryOpStatus{Message: "processing"}) + g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) updateError := func(e error) { g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true}) } - if err := applyGallery(g.modelPath, op.req, cm); err != nil { + if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) { + g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) + displayDownload(fileName, current, total, percentage) + }); err != nil { updateError(err) continue } - g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"}) + g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100}) } } }() } +var lastProgress time.Time = time.Now() +var startTime time.Time = time.Now() + +func displayDownload(fileName string, current string, total string, percentage float64) { + currentTime := time.Now() + + if currentTime.Sub(lastProgress) >= 5*time.Second { + + lastProgress = currentTime + + // calculate ETA based on percentage and elapsed time + var eta time.Duration + if percentage > 0 { + elapsed := currentTime.Sub(startTime) + eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed)) + } + + if total != "" { + log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta) + } else { + log.Debug().Msgf("Downloading: %s", current) + } + } +} + func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error { dat, err := os.ReadFile(s) if err != nil { @@ -128,13 +161,14 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error { } for _, r := range requests { - if err := applyGallery(modelPath, r, cm); err != nil { + if err := applyGallery(modelPath, r, cm, displayDownload); err != nil { return err } } return nil } + func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { var requests []ApplyGalleryModelRequest err := json.Unmarshal([]byte(s), &requests) @@ -143,7 +177,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { } for _, r := range requests { - if err := applyGallery(modelPath, r, cm); err != nil { + if err := applyGallery(modelPath, r, cm, displayDownload); err != nil { return err } } diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index f4f86ae..14a7d6a 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -3,10 +3,12 @@ package gallery import ( "crypto/sha256" "fmt" + "hash" "io" "net/http" "os" "path/filepath" + "strconv" "github.com/imdario/mergo" "github.com/rs/zerolog/log" @@ -93,7 +95,7 @@ func verifyPath(path, basePath string) error { return inTrustedRoot(c, basePath) } -func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}) error { +func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0755) if err != nil { @@ -168,27 +170,25 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st } defer outFile.Close() - if file.SHA256 != "" { - log.Debug().Msgf("Download and verifying %q", file.Filename) - - // Write file content and calculate SHA - hash := sha256.New() - _, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body) - if err != nil { - return fmt.Errorf("failed to write file %q: %v", file.Filename, err) - } + progress := &progressWriter{ + fileName: file.Filename, + total: resp.ContentLength, + hash: sha256.New(), + downloadStatus: downloadStatus, + } + _, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) + if err != nil { + return fmt.Errorf("failed to write file %q: %v", file.Filename, err) + } + if file.SHA256 != "" { // Verify SHA - calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil)) + calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) if calculatedSHA != file.SHA256 { return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) } } else { log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename) - _, err = io.Copy(outFile, resp.Body) - if err != nil { - return fmt.Errorf("failed to write file %q: %v", file.Filename, err) - } } log.Debug().Msgf("File %q downloaded and verified", file.Filename) @@ -255,6 +255,42 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st return nil } +type progressWriter struct { + fileName string + total int64 + written int64 + downloadStatus func(string, string, string, float64) + hash hash.Hash +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + n, err = pw.hash.Write(p) + pw.written += int64(n) + + if pw.total > 0 { + percentage := float64(pw.written) / float64(pw.total) * 100 + //log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) + pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) + } else { + pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0) + } + + return +} + +func formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return strconv.FormatInt(bytes, 10) + " B" + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + func calculateSHA(filePath string) (string, error) { file, err := os.Open(filePath) if err != nil { diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go index f0e580e..343bf6a 100644 --- a/pkg/gallery/models_test.go +++ b/pkg/gallery/models_test.go @@ -19,7 +19,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "", c, map[string]interface{}{}) + err = Apply(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { @@ -45,7 +45,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "foo", c, map[string]interface{}{}) + err = Apply(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -61,7 +61,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}) + err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { @@ -87,7 +87,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}) + err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) Expect(err).To(HaveOccurred()) }) })