feat: display download progress when installing models (#543)

renovate/github.com-imdario-mergo-1.x
Ettore Di Giacinto 1 year ago committed by GitHub
parent c9bbba4872
commit 84946e9275
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 54
      api/gallery.go
  2. 66
      pkg/gallery/models.go
  3. 8
      pkg/gallery/models_test.go

@ -10,10 +10,12 @@ import (
"os" "os"
"strings" "strings"
"sync" "sync"
"time"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@ -23,9 +25,12 @@ type galleryOp struct {
} }
type galleryOpStatus struct { type galleryOpStatus struct {
Error error `json:"error"` Error error `json:"error"`
Processed bool `json:"processed"` Processed bool `json:"processed"`
Message string `json:"message"` Message string `json:"message"`
Progress float64 `json:"progress"`
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
} }
type galleryApplier struct { type galleryApplier struct {
@ -43,7 +48,7 @@ func newGalleryApplier(modelPath string) *galleryApplier {
} }
} }
func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger) error { func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
url, err := req.DecodeURL() url, err := req.DecodeURL()
if err != nil { if err != nil {
return err return err
@ -71,7 +76,7 @@ func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerg
config.Files = append(config.Files, req.AdditionalFiles...) config.Files = append(config.Files, req.AdditionalFiles...)
if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides); err != nil { if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides, downloadStatus); err != nil {
return err return err
} }
@ -99,23 +104,51 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
case <-c.Done(): case <-c.Done():
return return
case op := <-g.C: case op := <-g.C:
g.updatestatus(op.id, &galleryOpStatus{Message: "processing"}) g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
updateError := func(e error) { updateError := func(e error) {
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true}) g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
} }
if err := applyGallery(g.modelPath, op.req, cm); err != nil { if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
displayDownload(fileName, current, total, percentage)
}); err != nil {
updateError(err) updateError(err)
continue continue
} }
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"}) g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
} }
} }
}() }()
} }
var lastProgress time.Time = time.Now()
var startTime time.Time = time.Now()
func displayDownload(fileName string, current string, total string, percentage float64) {
currentTime := time.Now()
if currentTime.Sub(lastProgress) >= 5*time.Second {
lastProgress = currentTime
// calculate ETA based on percentage and elapsed time
var eta time.Duration
if percentage > 0 {
elapsed := currentTime.Sub(startTime)
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
}
if total != "" {
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
} else {
log.Debug().Msgf("Downloading: %s", current)
}
}
}
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error { func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
dat, err := os.ReadFile(s) dat, err := os.ReadFile(s)
if err != nil { if err != nil {
@ -128,13 +161,14 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
} }
for _, r := range requests { for _, r := range requests {
if err := applyGallery(modelPath, r, cm); err != nil { if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
return err return err
} }
} }
return nil return nil
} }
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
var requests []ApplyGalleryModelRequest var requests []ApplyGalleryModelRequest
err := json.Unmarshal([]byte(s), &requests) err := json.Unmarshal([]byte(s), &requests)
@ -143,7 +177,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
} }
for _, r := range requests { for _, r := range requests {
if err := applyGallery(modelPath, r, cm); err != nil { if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
return err return err
} }
} }

@ -3,10 +3,12 @@ package gallery
import ( import (
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"hash"
"io" "io"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"github.com/imdario/mergo" "github.com/imdario/mergo"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -93,7 +95,7 @@ func verifyPath(path, basePath string) error {
return inTrustedRoot(c, basePath) return inTrustedRoot(c, basePath)
} }
func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}) error { func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist // Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0755) err := os.MkdirAll(basePath, 0755)
if err != nil { if err != nil {
@ -168,27 +170,25 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
} }
defer outFile.Close() defer outFile.Close()
if file.SHA256 != "" { progress := &progressWriter{
log.Debug().Msgf("Download and verifying %q", file.Filename) fileName: file.Filename,
total: resp.ContentLength,
// Write file content and calculate SHA hash: sha256.New(),
hash := sha256.New() downloadStatus: downloadStatus,
_, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body) }
if err != nil { _, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
return fmt.Errorf("failed to write file %q: %v", file.Filename, err) if err != nil {
} return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
}
if file.SHA256 != "" {
// Verify SHA // Verify SHA
calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil)) calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
if calculatedSHA != file.SHA256 { if calculatedSHA != file.SHA256 {
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
} }
} else { } else {
log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename) log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
_, err = io.Copy(outFile, resp.Body)
if err != nil {
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
}
} }
log.Debug().Msgf("File %q downloaded and verified", file.Filename) log.Debug().Msgf("File %q downloaded and verified", file.Filename)
@ -255,6 +255,42 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
return nil return nil
} }
type progressWriter struct {
fileName string
total int64
written int64
downloadStatus func(string, string, string, float64)
hash hash.Hash
}
func (pw *progressWriter) Write(p []byte) (n int, err error) {
n, err = pw.hash.Write(p)
pw.written += int64(n)
if pw.total > 0 {
percentage := float64(pw.written) / float64(pw.total) * 100
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
} else {
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
}
return
}
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return strconv.FormatInt(bytes, 10) + " B"
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
func calculateSHA(filePath string) (string, error) { func calculateSHA(filePath string) (string, error) {
file, err := os.Open(filePath) file, err := os.Open(filePath)
if err != nil { if err != nil {

@ -19,7 +19,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "", c, map[string]interface{}{}) err = Apply(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@ -45,7 +45,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "foo", c, map[string]interface{}{}) err = Apply(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@ -61,7 +61,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}) err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@ -87,7 +87,7 @@ var _ = Describe("Model test", func() {
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}) err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
}) })

Loading…
Cancel
Save