From 17a9719f3b589f0cb4350a87ffa75ccdad974ab2 Mon Sep 17 00:00:00 2001 From: Jesse Duffield Date: Sun, 14 Jun 2020 13:44:11 +1000 Subject: [PATCH] better separation of concerns --- main.go | 20 +++- pkg/commands/bind.go | 167 ++++++++++++++++++--------------- pkg/commands/horcrux.go | 90 ++++++++++++++++++ pkg/commands/horcrux_header.go | 10 -- pkg/commands/split.go | 10 +- pkg/commands/utils.go | 2 +- 6 files changed, 208 insertions(+), 91 deletions(-) create mode 100644 pkg/commands/horcrux.go delete mode 100644 pkg/commands/horcrux_header.go diff --git a/main.go b/main.go index 7668f74..7503f23 100644 --- a/main.go +++ b/main.go @@ -21,9 +21,27 @@ func main() { } else { dir = os.Args[2] } - if err := commands.Bind(dir); err != nil { + paths, err := commands.GetHorcruxPathsInDir(dir) + if err != nil { log.Fatal(err) } + overwrite := false + for { + if err := commands.Bind(paths, "", overwrite); err != nil { + if err != os.ErrExist { + log.Fatal(err) + } + overwriteResponse := commands.Prompt("A file already exists at destination. Overwrite? (Y/N):") + if overwriteResponse == "Y" || overwriteResponse == "y" || overwriteResponse == "yes" { + overwrite = true + } else { + log.Fatal("You have chosen not to overwrite the file. Cancelling.") + } + } else { + break + } + } + return } diff --git a/pkg/commands/bind.go b/pkg/commands/bind.go index 45a97b1..22fb36c 100644 --- a/pkg/commands/bind.go +++ b/pkg/commands/bind.go @@ -1,71 +1,125 @@ package commands import ( - "bufio" - "encoding/json" "errors" "fmt" "io" "io/ioutil" "os" "path/filepath" + "sort" + "strings" "github.com/jesseduffield/horcrux/pkg/multiplexing" "github.com/jesseduffield/horcrux/pkg/shamir" ) -func Bind(dir string) error { +func GetHorcruxPathsInDir(dir string) ([]string, error) { files, err := ioutil.ReadDir(dir) if err != nil { - return err + return nil, err } - filenames := []string{} + paths := []string{} for _, file := range files { if filepath.Ext(file.Name()) == ".horcrux" { - filenames = append(filenames, file.Name()) + paths = append(paths, file.Name()) } } - headers := []horcruxHeader{} - horcruxFiles := []*os.File{} + return paths, nil +} - for _, filename := range filenames { - file, err := os.Open(filename) - defer file.Close() +type byIndex []Horcrux + +func (h byIndex) Len() int { + return len(h) +} + +func (h byIndex) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h byIndex) Less(i, j int) bool { + return h[i].GetHeader().Index < h[j].GetHeader().Index +} + +func GetHorcruxes(paths []string) ([]Horcrux, error) { + horcruxes := []Horcrux{} + + for _, path := range paths { + currentHorcrux, err := NewHorcrux(path) if err != nil { - return err + return nil, err } - - currentHeader, err := getHeaderFromHorcruxFile(file) - if err != nil { - return err - } - - for _, header := range headers { - if header.Index == currentHeader.Index { + for _, horcrux := range horcruxes { + if horcrux.GetHeader().Index == currentHorcrux.GetHeader().Index && horcrux.GetHeader().OriginalFilename == currentHorcrux.GetHeader().OriginalFilename { // we've already obtained this horcrux so we'll skip this instance continue } } - if len(headers) > 0 && (currentHeader.OriginalFilename != headers[0].OriginalFilename || currentHeader.Timestamp != headers[0].Timestamp) { + horcruxes = append(horcruxes, *currentHorcrux) + } + + sort.Sort(byIndex(horcruxes)) + + return horcruxes, nil +} + +func ValidateHorcruxes(horcruxes []Horcrux) error { + if len(horcruxes) == 0 { + return errors.New("No horcruxes supplied") + } + + if len(horcruxes) < horcruxes[0].GetHeader().Threshold { + return fmt.Errorf( + "You do not have all the required horcruxes. There are %d required to resurrect the original file. You only have %d", + horcruxes[0].GetHeader().Threshold, + len(horcruxes), + ) + } + + for _, horcrux := range horcruxes { + if !strings.HasSuffix(horcrux.GetPath(), ".horcrux") { + return fmt.Errorf("%s is not a horcrux file (requires .horcrux extension)", horcrux.GetPath()) + } + if horcrux.GetHeader().OriginalFilename != horcruxes[0].GetHeader().OriginalFilename || horcrux.GetHeader().Timestamp != horcruxes[0].GetHeader().Timestamp { return errors.New("All horcruxes in the given directory must have the same original filename and timestamp.") } - - headers = append(headers, *currentHeader) - horcruxFiles = append(horcruxFiles, file) } - if len(headers) == 0 { - return errors.New("No horcruxes in directory") - } else if len(headers) < headers[0].Threshold { - return errors.New(fmt.Sprintf("You do not have all the required horcruxes. There are %d required to resurrect the original file. You only have %d", headers[0].Threshold, len(headers))) + return nil +} + +func Bind(paths []string, dstPath string, overwrite bool) error { + horcruxes, err := GetHorcruxes(paths) + if err != nil { + return err } - keyFragments := make([][]byte, len(headers)) + if err := ValidateHorcruxes(horcruxes); err != nil { + return err + } + + firstHorcrux := horcruxes[0] + + // if dstPath is empty we use the original filename + if dstPath == "" { + cwd, err := os.Getwd() + if err != nil { + return err + } + dstPath = filepath.Join(cwd, firstHorcrux.GetHeader().OriginalFilename) + } + + if fileExists(dstPath) && !overwrite { + return os.ErrExist + } + + keyFragments := make([][]byte, len(horcruxes)) for i := range keyFragments { - keyFragments[i] = headers[i].KeyFragment + keyFragments[i] = horcruxes[i].GetHeader().KeyFragment } key, err := shamir.Combine(keyFragments) @@ -74,28 +128,22 @@ func Bind(dir string) error { } var fileReader io.Reader - if headers[0].Total == headers[0].Threshold { - // sort by index - orderedHorcruxFiles := make([]*os.File, len(horcruxFiles)) - for i, h := range horcruxFiles { - orderedHorcruxFiles[headers[i].Index-1] = h + if firstHorcrux.GetHeader().Total == firstHorcrux.GetHeader().Threshold { + horcruxFiles := make([]*os.File, len(horcruxes)) + for i, horcrux := range horcruxes { + horcruxFiles[i] = horcrux.GetFile() } - fileReader = &multiplexing.Multiplexer{Readers: orderedHorcruxFiles} + fileReader = &multiplexing.Multiplexer{Readers: horcruxFiles} } else { - fileReader = horcruxFiles[0] // arbitrarily read from the first horcrux: they all contain the same contents + fileReader = firstHorcrux.GetFile() // arbitrarily read from the first horcrux: they all contain the same contents } reader := cryptoReader(fileReader, key) - newFilename := headers[0].OriginalFilename - if fileExists(newFilename) { - newFilename = prompt("A file already exists named '%s'. Enter new file name: ", newFilename) - } + _ = os.Truncate(dstPath, 0) - _ = os.Truncate(newFilename, 0) - - newFile, err := os.OpenFile(newFilename, os.O_WRONLY|os.O_CREATE, 0644) + newFile, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE, 0644) if err != nil { return err } @@ -108,34 +156,3 @@ func Bind(dir string) error { return err } - -// this function gets the header from the horcrux file and ensures that we leave -// the file with its read pointer at the start of the encrypted content -// so that we can later directly read from that point -// yes this is a side effect, no I'm not proud of it. -func getHeaderFromHorcruxFile(file *os.File) (*horcruxHeader, error) { - currentHeader := &horcruxHeader{} - scanner := bufio.NewScanner(file) - bytesBeforeBody := 0 - for scanner.Scan() { - line := scanner.Text() - bytesBeforeBody += len(scanner.Bytes()) + 1 - if line == "-- HEADER --" { - scanner.Scan() - bytesBeforeBody += len(scanner.Bytes()) + 1 - headerLine := scanner.Bytes() - json.Unmarshal(headerLine, currentHeader) - scanner.Scan() // one more to get past the body line - bytesBeforeBody += len(scanner.Bytes()) + 1 - break - } - } - if _, err := file.Seek(int64(bytesBeforeBody), io.SeekStart); err != nil { - return nil, err - } - - if currentHeader == nil { - return nil, errors.New("could not find header in horcrux file") - } - return currentHeader, nil -} diff --git a/pkg/commands/horcrux.go b/pkg/commands/horcrux.go new file mode 100644 index 0000000..7268cfe --- /dev/null +++ b/pkg/commands/horcrux.go @@ -0,0 +1,90 @@ +package commands + +import ( + "bufio" + "encoding/json" + "errors" + "io" + "os" +) + +type HorcruxHeader struct { + OriginalFilename string `json:"originalFilename"` + Timestamp int64 `json:"timestamp"` + Index int `json:"index"` + Total int `json:"total"` + Threshold int `json:"threshold"` + KeyFragment []byte `json:"keyFragment"` +} + +type Horcrux struct { + path string + header HorcruxHeader + file *os.File +} + +// returns a horcrux with its header parsed, and it's file's read pointer +// right after the header. +func NewHorcrux(path string) (*Horcrux, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + + header, err := GetHeaderFromHorcruxFile(file) + if err != nil { + return nil, err + } + + return &Horcrux{ + path: path, + file: file, + header: *header, + }, nil +} + +// this function gets the header from the horcrux file and ensures that we leave +// the file with its read pointer at the start of the encrypted content +// so that we can later directly read from that point +// yes this is a side effect, no I'm not proud of it. +func GetHeaderFromHorcruxFile(file *os.File) (*HorcruxHeader, error) { + currentHeader := &HorcruxHeader{} + scanner := bufio.NewScanner(file) + bytesBeforeBody := 0 + for scanner.Scan() { + line := scanner.Text() + bytesBeforeBody += len(scanner.Bytes()) + 1 + if line == "-- HEADER --" { + scanner.Scan() + bytesBeforeBody += len(scanner.Bytes()) + 1 + headerLine := scanner.Bytes() + if err := json.Unmarshal(headerLine, currentHeader); err != nil { + return nil, err + } + + scanner.Scan() // one more to get past the body line + bytesBeforeBody += len(scanner.Bytes()) + 1 + break + } + } + if _, err := file.Seek(int64(bytesBeforeBody), io.SeekStart); err != nil { + return nil, err + } + + if currentHeader == nil { + return nil, errors.New("could not find header in horcrux file") + } + return currentHeader, nil +} + +func (h *Horcrux) GetHeader() HorcruxHeader { + return h.header +} + +func (h *Horcrux) GetPath() string { + return h.path +} + +func (h *Horcrux) GetFile() *os.File { + return h.file +} diff --git a/pkg/commands/horcrux_header.go b/pkg/commands/horcrux_header.go deleted file mode 100644 index 4c09bba..0000000 --- a/pkg/commands/horcrux_header.go +++ /dev/null @@ -1,10 +0,0 @@ -package commands - -type horcruxHeader struct { - OriginalFilename string `json:"originalFilename"` - Timestamp int64 `json:"timestamp"` - Index int `json:"index"` - Total int `json:"total"` - Threshold int `json:"threshold"` - KeyFragment []byte `json:"keyFragment"` -} diff --git a/pkg/commands/split.go b/pkg/commands/split.go index db3dad3..6632390 100644 --- a/pkg/commands/split.go +++ b/pkg/commands/split.go @@ -64,7 +64,7 @@ func Split(path string, destination string, total int, threshold int) error { for i := range horcruxFiles { index := i + 1 - headerBytes, err := json.Marshal(&horcruxHeader{ + headerBytes, err := json.Marshal(&HorcruxHeader{ OriginalFilename: originalFilename, Timestamp: timestamp, Index: index, @@ -90,7 +90,9 @@ func Split(path string, destination string, total int, threshold int) error { } defer horcruxFile.Close() - horcruxFile.WriteString(header(index, total, headerBytes)) + if _, err := horcruxFile.WriteString(header(index, total, headerBytes)); err != nil { + return err + } horcruxFiles[i] = horcruxFile } @@ -133,7 +135,7 @@ func obtainTotalAndThreshold() (int, int, error) { threshold := *thresholdPtr if total == 0 { - totalStr := prompt("How many horcruxes do you want to split this file into? (2-99): ") + totalStr := Prompt("How many horcruxes do you want to split this file into? (2-99): ") var err error total, err = strconv.Atoi(totalStr) if err != nil { @@ -142,7 +144,7 @@ func obtainTotalAndThreshold() (int, int, error) { } if threshold == 0 { - thresholdStr := prompt("How many horcruxes should be required to reconstitute the original file? If you require all horcruxes, the resulting files will take up less space, but it will feel less magical (2-99): ") + thresholdStr := Prompt("How many horcruxes should be required to reconstitute the original file? If you require all horcruxes, the resulting files will take up less space, but it will feel less magical (2-99): ") var err error threshold, err = strconv.Atoi(thresholdStr) if err != nil { diff --git a/pkg/commands/utils.go b/pkg/commands/utils.go index 489efc7..8d3bd6b 100644 --- a/pkg/commands/utils.go +++ b/pkg/commands/utils.go @@ -30,7 +30,7 @@ func fileExists(filename string) bool { return !info.IsDir() } -func prompt(message string, args ...interface{}) string { +func Prompt(message string, args ...interface{}) string { reader := bufio.NewReader(os.Stdin) fmt.Printf(message, args...) input, _ := reader.ReadString('\n')