diff --git a/cmd/hauler/cli/store/load.go b/cmd/hauler/cli/store/load.go index c93977a..98b8710 100644 --- a/cmd/hauler/cli/store/load.go +++ b/cmd/hauler/cli/store/load.go @@ -2,10 +2,15 @@ package store import ( "context" + "io" + "net/url" "os" + "path/filepath" + "strings" "hauler.dev/go/hauler/internal/flags" "hauler.dev/go/hauler/pkg/archives" + "hauler.dev/go/hauler/pkg/artifacts/file/getter" "hauler.dev/go/hauler/pkg/consts" "hauler.dev/go/hauler/pkg/content" "hauler.dev/go/hauler/pkg/log" @@ -16,9 +21,23 @@ import ( func LoadCmd(ctx context.Context, o *flags.LoadOpts, rso *flags.StoreRootOpts, ro *flags.CliRootOpts) error { l := log.FromContext(ctx) + tempOverride := o.TempOverride + + if tempOverride == "" { + tempOverride = os.Getenv(consts.HaulerTempDir) + } + + tempDir, err := os.MkdirTemp(tempOverride, consts.DefaultHaulerTempDirName) + if err != nil { + return err + } + defer os.RemoveAll(tempDir) + + l.Debugf("using temporary directory at [%s]", tempDir) + for _, fileName := range o.FileName { - l.Infof("loading haul [%s] to [%s]", o.FileName, o.StoreDir) - err := unarchiveLayoutTo(ctx, fileName, o.StoreDir, o.TempOverride) + l.Infof("loading haul [%s] to [%s]", fileName, o.StoreDir) + err := unarchiveLayoutTo(ctx, fileName, o.StoreDir, tempDir) if err != nil { return err } @@ -27,27 +46,20 @@ func LoadCmd(ctx context.Context, o *flags.LoadOpts, rso *flags.StoreRootOpts, r return nil } -// unarchiveLayoutTo accepts an archived OCI layout, extracts the contents to an existing OCI layout, and preserves the index -func unarchiveLayoutTo(ctx context.Context, haulPath string, dest string, tempOverride string) error { +// accepts an archived OCI layout, extracts the contents to an existing OCI layout, and preserves the index +func unarchiveLayoutTo(ctx context.Context, haulPath string, dest string, tempDir string) error { l := log.FromContext(ctx) - var tempDir string - - if tempOverride != "" { - tempDir = tempOverride - } else { - - parent := os.Getenv(consts.HaulerTempDir) + // if archivePath detects a remote URL... download it + if strings.HasPrefix(haulPath, "http://") || strings.HasPrefix(haulPath, "https://") { + l.Debugf("detected remote archive... starting download... [%s]", haulPath) var err error - tempDir, err = os.MkdirTemp(parent, consts.DefaultHaulerTempDirName) + haulPath, err = downloadRemote(ctx, haulPath, tempDir) if err != nil { return err } - defer os.RemoveAll(tempDir) } - l.Debugf("using temporary directory [%s]", tempDir) - if err := archives.Unarchive(ctx, haulPath, tempDir); err != nil { return err } @@ -65,3 +77,35 @@ func unarchiveLayoutTo(ctx context.Context, haulPath string, dest string, tempOv _, err = s.CopyAll(ctx, ts, nil) return err } + +// downloadRemote downloads the remote file using the existing getter +func downloadRemote(ctx context.Context, remoteURL, tempDirDest string) (string, error) { + parsedURL, err := url.Parse(remoteURL) + if err != nil { + return "", err + } + h := getter.NewHttp() + rc, err := h.Open(ctx, parsedURL) + if err != nil { + return "", err + } + defer rc.Close() + + fileName := h.Name(parsedURL) + if fileName == "" { + fileName = filepath.Base(parsedURL.Path) + } + + localPath := filepath.Join(tempDirDest, fileName) + out, err := os.Create(localPath) + if err != nil { + return "", err + } + defer out.Close() + + if _, err = io.Copy(out, rc); err != nil { + return "", err + } + + return localPath, nil +} diff --git a/internal/flags/load.go b/internal/flags/load.go index 66037e7..bc97448 100644 --- a/internal/flags/load.go +++ b/internal/flags/load.go @@ -16,6 +16,6 @@ func (o *LoadOpts) AddFlags(cmd *cobra.Command) { // On Unix systems, the default is $TMPDIR if non-empty, else /tmp // On Windows, the default is GetTempPath, returning the first value from %TMP%, %TEMP%, %USERPROFILE%, or Windows directory - f.StringSliceVarP(&o.FileName, "filename", "f", []string{consts.DefaultHaulerArchiveName}, "Specify the name of haul(s) to sync") + f.StringSliceVarP(&o.FileName, "filename", "f", []string{consts.DefaultHaulerArchiveName}, "(Optional) Specify the name of inputted haul(s)") f.StringVarP(&o.TempOverride, "tempdir", "t", "", "(Optional) Override the default temporary directiory determined by the OS") }