You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
238 lines
6.1 KiB
238 lines
6.1 KiB
2 years ago
|
package gallery
|
||
|
|
||
|
import (
|
||
|
"crypto/sha256"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net/http"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
|
||
|
"github.com/rs/zerolog/log"
|
||
|
"gopkg.in/yaml.v2"
|
||
|
)
|
||
|
|
||
|
/*
|
||
|
|
||
|
description: |
|
||
|
foo
|
||
|
license: ""
|
||
|
|
||
|
urls:
|
||
|
-
|
||
|
-
|
||
|
|
||
|
name: "bar"
|
||
|
|
||
|
config_file: |
|
||
|
# Note: name will be injected, or generated from the alias wanted by the user
|
||
|
threads: 14
|
||
|
|
||
|
files:
|
||
|
- filename: ""
|
||
|
sha256: ""
|
||
|
uri: ""
|
||
|
|
||
|
prompt_templates:
|
||
|
- name: ""
|
||
|
content: ""
|
||
|
|
||
|
*/
|
||
|
|
||
|
type Config struct {
|
||
|
Description string `yaml:"description"`
|
||
|
License string `yaml:"license"`
|
||
|
URLs []string `yaml:"urls"`
|
||
|
Name string `yaml:"name"`
|
||
|
ConfigFile string `yaml:"config_file"`
|
||
|
Files []File `yaml:"files"`
|
||
|
PromptTemplates []PromptTemplate `yaml:"prompt_templates"`
|
||
|
}
|
||
|
|
||
|
// File describes one artifact that Apply keeps present under its base path.
type File struct {
	// Filename is the destination path, joined onto Apply's base path.
	Filename string `yaml:"filename"`
	// SHA256 is the expected digest as a lowercase hex string (it is compared
	// verbatim against a "%x"-formatted SHA-256). When empty, Apply neither
	// verifies a download nor re-downloads a file that already exists.
	SHA256 string `yaml:"sha256"`
	// URI is the location Apply fetches the file from with an HTTP GET.
	URI string `yaml:"uri"`
}
|
||
|
|
||
|
// PromptTemplate is a named template body that Apply materialises on disk.
type PromptTemplate struct {
	// Name is the file name without extension; Apply appends ".tmpl" and
	// joins the result onto its base path.
	Name string `yaml:"name"`
	// Content is written to the template file unchanged.
	Content string `yaml:"content"`
}
|
||
|
|
||
|
func ReadConfigFile(filePath string) (*Config, error) {
|
||
|
// Read the YAML file
|
||
|
yamlFile, err := os.ReadFile(filePath)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to read YAML file: %v", err)
|
||
|
}
|
||
|
|
||
|
// Unmarshal YAML data into a Config struct
|
||
|
var config Config
|
||
|
err = yaml.Unmarshal(yamlFile, &config)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
|
||
|
}
|
||
|
|
||
|
return &config, nil
|
||
|
}
|
||
|
|
||
|
func Apply(basePath, nameOverride string, config *Config) error {
|
||
|
// Create base path if it doesn't exist
|
||
|
err := os.MkdirAll(basePath, 0755)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to create base path: %v", err)
|
||
|
}
|
||
|
|
||
|
// Download files and verify their SHA
|
||
|
for _, file := range config.Files {
|
||
|
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
|
||
|
|
||
|
// Create file path
|
||
|
filePath := filepath.Join(basePath, file.Filename)
|
||
|
|
||
|
// Check if the file already exists
|
||
|
_, err := os.Stat(filePath)
|
||
|
if err == nil {
|
||
|
// File exists, check SHA
|
||
|
if file.SHA256 != "" {
|
||
|
// Verify SHA
|
||
|
calculatedSHA, err := calculateSHA(filePath)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err)
|
||
|
}
|
||
|
if calculatedSHA == file.SHA256 {
|
||
|
// SHA matches, skip downloading
|
||
|
log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename)
|
||
|
continue
|
||
|
}
|
||
|
// SHA doesn't match, delete the file and download again
|
||
|
err = os.Remove(filePath)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err)
|
||
|
}
|
||
|
log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)
|
||
|
|
||
|
} else {
|
||
|
// SHA is missing, skip downloading
|
||
|
log.Debug().Msgf("File %q already exists. Skipping download", file.Filename)
|
||
|
continue
|
||
|
}
|
||
|
} else if !os.IsNotExist(err) {
|
||
|
// Error occurred while checking file existence
|
||
|
return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err)
|
||
|
}
|
||
|
|
||
|
log.Debug().Msgf("Downloading %q", file.URI)
|
||
|
|
||
|
// Download file
|
||
|
resp, err := http.Get(file.URI)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to download file %q: %v", file.Filename, err)
|
||
|
}
|
||
|
defer resp.Body.Close()
|
||
|
|
||
|
// Create parent directory
|
||
|
err = os.MkdirAll(filepath.Dir(filePath), 0755)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err)
|
||
|
}
|
||
|
|
||
|
// Create and write file content
|
||
|
outFile, err := os.Create(filePath)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to create file %q: %v", file.Filename, err)
|
||
|
}
|
||
|
defer outFile.Close()
|
||
|
|
||
|
if file.SHA256 != "" {
|
||
|
log.Debug().Msgf("Download and verifying %q", file.Filename)
|
||
|
|
||
|
// Write file content and calculate SHA
|
||
|
hash := sha256.New()
|
||
|
_, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
|
||
|
}
|
||
|
|
||
|
// Verify SHA
|
||
|
calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil))
|
||
|
if calculatedSHA != file.SHA256 {
|
||
|
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
|
||
|
}
|
||
|
} else {
|
||
|
log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
|
||
|
_, err = io.Copy(outFile, resp.Body)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
log.Debug().Msgf("File %q downloaded and verified", file.Filename)
|
||
|
}
|
||
|
|
||
|
// Write prompt template contents to separate files
|
||
|
for _, template := range config.PromptTemplates {
|
||
|
// Create file path
|
||
|
filePath := filepath.Join(basePath, template.Name+".tmpl")
|
||
|
|
||
|
// Create parent directory
|
||
|
err := os.MkdirAll(filepath.Dir(filePath), 0755)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err)
|
||
|
}
|
||
|
// Create and write file content
|
||
|
err = os.WriteFile(filePath, []byte(template.Content), 0644)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to write prompt template %q: %v", template.Name, err)
|
||
|
}
|
||
|
|
||
|
log.Debug().Msgf("Prompt template %q written", template.Name)
|
||
|
}
|
||
|
|
||
|
name := config.Name
|
||
|
if nameOverride != "" {
|
||
|
name = nameOverride
|
||
|
}
|
||
|
|
||
|
configFilePath := filepath.Join(basePath, name+".yaml")
|
||
|
|
||
|
// Read and update config file as map[string]interface{}
|
||
|
configMap := make(map[string]interface{})
|
||
|
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to unmarshal config YAML: %v", err)
|
||
|
}
|
||
|
|
||
|
configMap["name"] = name
|
||
|
|
||
|
// Write updated config file
|
||
|
updatedConfigYAML, err := yaml.Marshal(configMap)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to marshal updated config YAML: %v", err)
|
||
|
}
|
||
|
|
||
|
err = os.WriteFile(configFilePath, updatedConfigYAML, 0644)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to write updated config file: %v", err)
|
||
|
}
|
||
|
|
||
|
log.Debug().Msgf("Written config file %s", configFilePath)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// calculateSHA streams the file at filePath through SHA-256 and returns the
// digest as a lowercase hexadecimal string. Any error from opening or reading
// the file is returned unchanged, together with an empty string.
func calculateSHA(filePath string) (string, error) {
	src, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer src.Close()

	digest := sha256.New()
	_, err = io.Copy(digest, src)
	if err != nil {
		return "", err
	}

	sum := digest.Sum(nil)
	return fmt.Sprintf("%x", sum), nil
}
|