feat: add /models/apply endpoint to prepare models (#286)
parent
5617e50ebc
commit
cc9aa9eb3f
@ -0,0 +1,146 @@ |
|||||||
|
package api |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"io/ioutil" |
||||||
|
"net/http" |
||||||
|
"sync" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/gallery" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/google/uuid" |
||||||
|
"gopkg.in/yaml.v3" |
||||||
|
) |
||||||
|
|
||||||
|
type galleryOp struct { |
||||||
|
req ApplyGalleryModelRequest |
||||||
|
id string |
||||||
|
} |
||||||
|
|
||||||
|
type galleryOpStatus struct { |
||||||
|
Error error `json:"error"` |
||||||
|
Processed bool `json:"processed"` |
||||||
|
Message string `json:"message"` |
||||||
|
} |
||||||
|
|
||||||
|
type galleryApplier struct { |
||||||
|
modelPath string |
||||||
|
sync.Mutex |
||||||
|
C chan galleryOp |
||||||
|
statuses map[string]*galleryOpStatus |
||||||
|
} |
||||||
|
|
||||||
|
func newGalleryApplier(modelPath string) *galleryApplier { |
||||||
|
return &galleryApplier{ |
||||||
|
modelPath: modelPath, |
||||||
|
C: make(chan galleryOp), |
||||||
|
statuses: make(map[string]*galleryOpStatus), |
||||||
|
} |
||||||
|
} |
||||||
|
func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) { |
||||||
|
g.Lock() |
||||||
|
defer g.Unlock() |
||||||
|
g.statuses[s] = op |
||||||
|
} |
||||||
|
|
||||||
|
func (g *galleryApplier) getstatus(s string) *galleryOpStatus { |
||||||
|
g.Lock() |
||||||
|
defer g.Unlock() |
||||||
|
|
||||||
|
return g.statuses[s] |
||||||
|
} |
||||||
|
|
||||||
|
func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { |
||||||
|
go func() { |
||||||
|
for { |
||||||
|
select { |
||||||
|
case <-c.Done(): |
||||||
|
return |
||||||
|
case op := <-g.C: |
||||||
|
g.updatestatus(op.id, &galleryOpStatus{Message: "processing"}) |
||||||
|
|
||||||
|
updateError := func(e error) { |
||||||
|
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true}) |
||||||
|
} |
||||||
|
// Send a GET request to the URL
|
||||||
|
response, err := http.Get(op.req.URL) |
||||||
|
if err != nil { |
||||||
|
updateError(err) |
||||||
|
continue |
||||||
|
} |
||||||
|
defer response.Body.Close() |
||||||
|
|
||||||
|
// Read the response body
|
||||||
|
body, err := ioutil.ReadAll(response.Body) |
||||||
|
if err != nil { |
||||||
|
updateError(err) |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
// Unmarshal YAML data into a Config struct
|
||||||
|
var config gallery.Config |
||||||
|
err = yaml.Unmarshal(body, &config) |
||||||
|
if err != nil { |
||||||
|
updateError(fmt.Errorf("failed to unmarshal YAML: %v", err)) |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
if err := gallery.Apply(g.modelPath, op.req.Name, &config); err != nil { |
||||||
|
updateError(err) |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
// Reload models
|
||||||
|
if err := cm.LoadConfigs(g.modelPath); err != nil { |
||||||
|
updateError(err) |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"}) |
||||||
|
} |
||||||
|
} |
||||||
|
}() |
||||||
|
} |
||||||
|
|
||||||
|
// endpoints
|
||||||
|
|
||||||
|
type ApplyGalleryModelRequest struct { |
||||||
|
URL string `json:"url"` |
||||||
|
Name string `json:"name"` |
||||||
|
} |
||||||
|
|
||||||
|
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
|
||||||
|
status := g.getstatus(c.Params("uid")) |
||||||
|
if status == nil { |
||||||
|
return fmt.Errorf("could not find any status for ID") |
||||||
|
} |
||||||
|
|
||||||
|
return c.JSON(status) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) func(c *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
input := new(ApplyGalleryModelRequest) |
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
uuid, err := uuid.NewUUID() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
g <- galleryOp{ |
||||||
|
req: *input, |
||||||
|
id: uuid.String(), |
||||||
|
} |
||||||
|
return c.JSON(struct { |
||||||
|
ID string `json:"uid"` |
||||||
|
StatusURL string `json:"status"` |
||||||
|
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,13 @@ |
|||||||
|
package gallery_test |
||||||
|
|
||||||
|
import ( |
||||||
|
"testing" |
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2" |
||||||
|
. "github.com/onsi/gomega" |
||||||
|
) |
||||||
|
|
||||||
|
func TestGallery(t *testing.T) { |
||||||
|
RegisterFailHandler(Fail) |
||||||
|
RunSpecs(t, "Gallery test suite") |
||||||
|
} |
@ -0,0 +1,237 @@ |
|||||||
|
package gallery |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/sha256" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"net/http" |
||||||
|
"os" |
||||||
|
"path/filepath" |
||||||
|
|
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
"gopkg.in/yaml.v2" |
||||||
|
) |
||||||
|
|
||||||
|
/* |
||||||
|
|
||||||
|
description: | |
||||||
|
foo |
||||||
|
license: "" |
||||||
|
|
||||||
|
urls: |
||||||
|
- |
||||||
|
- |
||||||
|
|
||||||
|
name: "bar" |
||||||
|
|
||||||
|
config_file: | |
||||||
|
# Note, name will be injected. or generated by the alias wanted by the user |
||||||
|
threads: 14 |
||||||
|
|
||||||
|
files: |
||||||
|
- filename: "" |
||||||
|
sha: "" |
||||||
|
uri: "" |
||||||
|
|
||||||
|
prompt_templates: |
||||||
|
- name: "" |
||||||
|
content: "" |
||||||
|
|
||||||
|
*/ |
||||||
|
|
||||||
|
type Config struct { |
||||||
|
Description string `yaml:"description"` |
||||||
|
License string `yaml:"license"` |
||||||
|
URLs []string `yaml:"urls"` |
||||||
|
Name string `yaml:"name"` |
||||||
|
ConfigFile string `yaml:"config_file"` |
||||||
|
Files []File `yaml:"files"` |
||||||
|
PromptTemplates []PromptTemplate `yaml:"prompt_templates"` |
||||||
|
} |
||||||
|
|
||||||
|
type File struct { |
||||||
|
Filename string `yaml:"filename"` |
||||||
|
SHA256 string `yaml:"sha256"` |
||||||
|
URI string `yaml:"uri"` |
||||||
|
} |
||||||
|
|
||||||
|
type PromptTemplate struct { |
||||||
|
Name string `yaml:"name"` |
||||||
|
Content string `yaml:"content"` |
||||||
|
} |
||||||
|
|
||||||
|
func ReadConfigFile(filePath string) (*Config, error) { |
||||||
|
// Read the YAML file
|
||||||
|
yamlFile, err := os.ReadFile(filePath) |
||||||
|
if err != nil { |
||||||
|
return nil, fmt.Errorf("failed to read YAML file: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
// Unmarshal YAML data into a Config struct
|
||||||
|
var config Config |
||||||
|
err = yaml.Unmarshal(yamlFile, &config) |
||||||
|
if err != nil { |
||||||
|
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
return &config, nil |
||||||
|
} |
||||||
|
|
||||||
|
func Apply(basePath, nameOverride string, config *Config) error { |
||||||
|
// Create base path if it doesn't exist
|
||||||
|
err := os.MkdirAll(basePath, 0755) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to create base path: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
// Download files and verify their SHA
|
||||||
|
for _, file := range config.Files { |
||||||
|
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) |
||||||
|
|
||||||
|
// Create file path
|
||||||
|
filePath := filepath.Join(basePath, file.Filename) |
||||||
|
|
||||||
|
// Check if the file already exists
|
||||||
|
_, err := os.Stat(filePath) |
||||||
|
if err == nil { |
||||||
|
// File exists, check SHA
|
||||||
|
if file.SHA256 != "" { |
||||||
|
// Verify SHA
|
||||||
|
calculatedSHA, err := calculateSHA(filePath) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err) |
||||||
|
} |
||||||
|
if calculatedSHA == file.SHA256 { |
||||||
|
// SHA matches, skip downloading
|
||||||
|
log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename) |
||||||
|
continue |
||||||
|
} |
||||||
|
// SHA doesn't match, delete the file and download again
|
||||||
|
err = os.Remove(filePath) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err) |
||||||
|
} |
||||||
|
log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath) |
||||||
|
|
||||||
|
} else { |
||||||
|
// SHA is missing, skip downloading
|
||||||
|
log.Debug().Msgf("File %q already exists. Skipping download", file.Filename) |
||||||
|
continue |
||||||
|
} |
||||||
|
} else if !os.IsNotExist(err) { |
||||||
|
// Error occurred while checking file existence
|
||||||
|
return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Downloading %q", file.URI) |
||||||
|
|
||||||
|
// Download file
|
||||||
|
resp, err := http.Get(file.URI) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to download file %q: %v", file.Filename, err) |
||||||
|
} |
||||||
|
defer resp.Body.Close() |
||||||
|
|
||||||
|
// Create parent directory
|
||||||
|
err = os.MkdirAll(filepath.Dir(filePath), 0755) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err) |
||||||
|
} |
||||||
|
|
||||||
|
// Create and write file content
|
||||||
|
outFile, err := os.Create(filePath) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to create file %q: %v", file.Filename, err) |
||||||
|
} |
||||||
|
defer outFile.Close() |
||||||
|
|
||||||
|
if file.SHA256 != "" { |
||||||
|
log.Debug().Msgf("Download and verifying %q", file.Filename) |
||||||
|
|
||||||
|
// Write file content and calculate SHA
|
||||||
|
hash := sha256.New() |
||||||
|
_, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to write file %q: %v", file.Filename, err) |
||||||
|
} |
||||||
|
|
||||||
|
// Verify SHA
|
||||||
|
calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil)) |
||||||
|
if calculatedSHA != file.SHA256 { |
||||||
|
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) |
||||||
|
} |
||||||
|
} else { |
||||||
|
log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename) |
||||||
|
_, err = io.Copy(outFile, resp.Body) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to write file %q: %v", file.Filename, err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("File %q downloaded and verified", file.Filename) |
||||||
|
} |
||||||
|
|
||||||
|
// Write prompt template contents to separate files
|
||||||
|
for _, template := range config.PromptTemplates { |
||||||
|
// Create file path
|
||||||
|
filePath := filepath.Join(basePath, template.Name+".tmpl") |
||||||
|
|
||||||
|
// Create parent directory
|
||||||
|
err := os.MkdirAll(filepath.Dir(filePath), 0755) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err) |
||||||
|
} |
||||||
|
// Create and write file content
|
||||||
|
err = os.WriteFile(filePath, []byte(template.Content), 0644) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to write prompt template %q: %v", template.Name, err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Prompt template %q written", template.Name) |
||||||
|
} |
||||||
|
|
||||||
|
name := config.Name |
||||||
|
if nameOverride != "" { |
||||||
|
name = nameOverride |
||||||
|
} |
||||||
|
|
||||||
|
configFilePath := filepath.Join(basePath, name+".yaml") |
||||||
|
|
||||||
|
// Read and update config file as map[string]interface{}
|
||||||
|
configMap := make(map[string]interface{}) |
||||||
|
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to unmarshal config YAML: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
configMap["name"] = name |
||||||
|
|
||||||
|
// Write updated config file
|
||||||
|
updatedConfigYAML, err := yaml.Marshal(configMap) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to marshal updated config YAML: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
err = os.WriteFile(configFilePath, updatedConfigYAML, 0644) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed to write updated config file: %v", err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Written config file %s", configFilePath) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func calculateSHA(filePath string) (string, error) { |
||||||
|
file, err := os.Open(filePath) |
||||||
|
if err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
defer file.Close() |
||||||
|
|
||||||
|
hash := sha256.New() |
||||||
|
if _, err := io.Copy(hash, file); err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
|
||||||
|
return fmt.Sprintf("%x", hash.Sum(nil)), nil |
||||||
|
} |
@ -0,0 +1,30 @@ |
|||||||
|
package gallery_test |
||||||
|
|
||||||
|
import ( |
||||||
|
"os" |
||||||
|
"path/filepath" |
||||||
|
|
||||||
|
. "github.com/go-skynet/LocalAI/pkg/gallery" |
||||||
|
. "github.com/onsi/ginkgo/v2" |
||||||
|
. "github.com/onsi/gomega" |
||||||
|
) |
||||||
|
|
||||||
|
var _ = Describe("Model test", func() { |
||||||
|
Context("Downloading", func() { |
||||||
|
It("applies model correctly", func() { |
||||||
|
tempdir, err := os.MkdirTemp("", "test") |
||||||
|
Expect(err).ToNot(HaveOccurred()) |
||||||
|
defer os.RemoveAll(tempdir) |
||||||
|
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) |
||||||
|
Expect(err).ToNot(HaveOccurred()) |
||||||
|
|
||||||
|
err = Apply(tempdir, "", c) |
||||||
|
Expect(err).ToNot(HaveOccurred()) |
||||||
|
|
||||||
|
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { |
||||||
|
_, err = os.Stat(filepath.Join(tempdir, f)) |
||||||
|
Expect(err).ToNot(HaveOccurred()) |
||||||
|
} |
||||||
|
}) |
||||||
|
}) |
||||||
|
}) |
@ -0,0 +1,40 @@ |
|||||||
|
name: "cerebras" |
||||||
|
description: | |
||||||
|
cerebras |
||||||
|
license: "Apache 2.0" |
||||||
|
|
||||||
|
config_file: | |
||||||
|
parameters: |
||||||
|
model: cerebras |
||||||
|
top_k: 80 |
||||||
|
temperature: 0.2 |
||||||
|
top_p: 0.7 |
||||||
|
context_size: 1024 |
||||||
|
stopwords: |
||||||
|
- "HUMAN:" |
||||||
|
- "GPT:" |
||||||
|
roles: |
||||||
|
user: "" |
||||||
|
system: "" |
||||||
|
template: |
||||||
|
completion: "cerebras-completion" |
||||||
|
chat: cerebras-chat |
||||||
|
|
||||||
|
files: |
||||||
|
- filename: "cerebras" |
||||||
|
sha256: "c947051ae4dba9530ca55d923a7a484acd65664c8633462c8ccd4bb7848f2c65" |
||||||
|
uri: "https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerebras-111m-q4_2.bin" |
||||||
|
|
||||||
|
prompt_templates: |
||||||
|
- name: "cerebras-completion" |
||||||
|
content: | |
||||||
|
Complete the prompt |
||||||
|
### Prompt: |
||||||
|
{{.Input}} |
||||||
|
### Response: |
||||||
|
- name: "cerebras-chat" |
||||||
|
content: | |
||||||
|
The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response. |
||||||
|
### Prompt: |
||||||
|
{{.Input}} |
||||||
|
### Response: |
Loading…
Reference in new issue