Skip to content

Commit

Permalink
feat(prepare): allow to specify additional files to download (#1526)
Browse files Browse the repository at this point in the history
  • Loading branch information
mudler authored Jan 1, 2024
1 parent f068efe commit 522659e
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions api/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ type Config struct {
// CUDA
// Explicitly enable CUDA or not (some backends might need it)
CUDA bool `yaml:"cuda"`

DownloadFiles []File `yaml:"download_files"`
}

type File struct {
Filename string `yaml:"filename" json:"filename"`
SHA256 string `yaml:"sha256" json:"sha256"`
URI string `yaml:"uri" json:"uri"`
}

type VallE struct {
Expand Down Expand Up @@ -272,10 +280,29 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
cm.Lock()
defer cm.Unlock()

status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
}

log.Info().Msgf("Preloading models from %s", modelPath)

for i, config := range cm.configs {

// Download files and verify their SHA
for _, file := range config.DownloadFiles {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)

if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
return err
}
// Create file path
filePath := filepath.Join(modelPath, file.Filename)

if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
return err
}
}

modelURL := config.PredictionOptions.Model
modelURL = utils.ConvertURL(modelURL)

Expand All @@ -285,9 +312,7 @@ func (cm *ConfigLoader) Preload(modelPath string) error {

// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
})
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
if err != nil {
return err
}
Expand Down

0 comments on commit 522659e

Please sign in to comment.