better separation of concerns

This commit is contained in:
Jesse Duffield 2020-06-14 13:44:11 +10:00
parent 62f26a105e
commit 17a9719f3b
6 changed files with 208 additions and 91 deletions

20
main.go
View file

@ -21,9 +21,27 @@ func main() {
} else {
dir = os.Args[2]
}
if err := commands.Bind(dir); err != nil {
paths, err := commands.GetHorcruxPathsInDir(dir)
if err != nil {
log.Fatal(err)
}
overwrite := false
for {
if err := commands.Bind(paths, "", overwrite); err != nil {
if err != os.ErrExist {
log.Fatal(err)
}
overwriteResponse := commands.Prompt("A file already exists at destination. Overwrite? (Y/N):")
if overwriteResponse == "Y" || overwriteResponse == "y" || overwriteResponse == "yes" {
overwrite = true
} else {
log.Fatal("You have chosen not to overwrite the file. Cancelling.")
}
} else {
break
}
}
return
}

View file

@ -1,71 +1,125 @@
package commands
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"sort"
"strings"
"github.com/jesseduffield/horcrux/pkg/multiplexing"
"github.com/jesseduffield/horcrux/pkg/shamir"
)
func Bind(dir string) error {
func GetHorcruxPathsInDir(dir string) ([]string, error) {
files, err := ioutil.ReadDir(dir)
if err != nil {
return err
return nil, err
}
filenames := []string{}
paths := []string{}
for _, file := range files {
if filepath.Ext(file.Name()) == ".horcrux" {
filenames = append(filenames, file.Name())
paths = append(paths, file.Name())
}
}
headers := []horcruxHeader{}
horcruxFiles := []*os.File{}
return paths, nil
}
for _, filename := range filenames {
file, err := os.Open(filename)
defer file.Close()
type byIndex []Horcrux
func (h byIndex) Len() int {
return len(h)
}
func (h byIndex) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
}
func (h byIndex) Less(i, j int) bool {
return h[i].GetHeader().Index < h[j].GetHeader().Index
}
func GetHorcruxes(paths []string) ([]Horcrux, error) {
horcruxes := []Horcrux{}
for _, path := range paths {
currentHorcrux, err := NewHorcrux(path)
if err != nil {
return err
return nil, err
}
currentHeader, err := getHeaderFromHorcruxFile(file)
if err != nil {
return err
}
for _, header := range headers {
if header.Index == currentHeader.Index {
for _, horcrux := range horcruxes {
if horcrux.GetHeader().Index == currentHorcrux.GetHeader().Index && horcrux.GetHeader().OriginalFilename == currentHorcrux.GetHeader().OriginalFilename {
// we've already obtained this horcrux so we'll skip this instance
continue
}
}
if len(headers) > 0 && (currentHeader.OriginalFilename != headers[0].OriginalFilename || currentHeader.Timestamp != headers[0].Timestamp) {
horcruxes = append(horcruxes, *currentHorcrux)
}
sort.Sort(byIndex(horcruxes))
return horcruxes, nil
}
func ValidateHorcruxes(horcruxes []Horcrux) error {
if len(horcruxes) == 0 {
return errors.New("No horcruxes supplied")
}
if len(horcruxes) < horcruxes[0].GetHeader().Threshold {
return fmt.Errorf(
"You do not have all the required horcruxes. There are %d required to resurrect the original file. You only have %d",
horcruxes[0].GetHeader().Threshold,
len(horcruxes),
)
}
for _, horcrux := range horcruxes {
if !strings.HasSuffix(horcrux.GetPath(), ".horcrux") {
return fmt.Errorf("%s is not a horcrux file (requires .horcrux extension)", horcrux.GetPath())
}
if horcrux.GetHeader().OriginalFilename != horcruxes[0].GetHeader().OriginalFilename || horcrux.GetHeader().Timestamp != horcruxes[0].GetHeader().Timestamp {
return errors.New("All horcruxes in the given directory must have the same original filename and timestamp.")
}
headers = append(headers, *currentHeader)
horcruxFiles = append(horcruxFiles, file)
}
if len(headers) == 0 {
return errors.New("No horcruxes in directory")
} else if len(headers) < headers[0].Threshold {
return errors.New(fmt.Sprintf("You do not have all the required horcruxes. There are %d required to resurrect the original file. You only have %d", headers[0].Threshold, len(headers)))
return nil
}
func Bind(paths []string, dstPath string, overwrite bool) error {
horcruxes, err := GetHorcruxes(paths)
if err != nil {
return err
}
keyFragments := make([][]byte, len(headers))
if err := ValidateHorcruxes(horcruxes); err != nil {
return err
}
firstHorcrux := horcruxes[0]
// if dstPath is empty we use the original filename
if dstPath == "" {
cwd, err := os.Getwd()
if err != nil {
return err
}
dstPath = filepath.Join(cwd, firstHorcrux.GetHeader().OriginalFilename)
}
if fileExists(dstPath) && !overwrite {
return os.ErrExist
}
keyFragments := make([][]byte, len(horcruxes))
for i := range keyFragments {
keyFragments[i] = headers[i].KeyFragment
keyFragments[i] = horcruxes[i].GetHeader().KeyFragment
}
key, err := shamir.Combine(keyFragments)
@ -74,28 +128,22 @@ func Bind(dir string) error {
}
var fileReader io.Reader
if headers[0].Total == headers[0].Threshold {
// sort by index
orderedHorcruxFiles := make([]*os.File, len(horcruxFiles))
for i, h := range horcruxFiles {
orderedHorcruxFiles[headers[i].Index-1] = h
if firstHorcrux.GetHeader().Total == firstHorcrux.GetHeader().Threshold {
horcruxFiles := make([]*os.File, len(horcruxes))
for i, horcrux := range horcruxes {
horcruxFiles[i] = horcrux.GetFile()
}
fileReader = &multiplexing.Multiplexer{Readers: orderedHorcruxFiles}
fileReader = &multiplexing.Multiplexer{Readers: horcruxFiles}
} else {
fileReader = horcruxFiles[0] // arbitrarily read from the first horcrux: they all contain the same contents
fileReader = firstHorcrux.GetFile() // arbitrarily read from the first horcrux: they all contain the same contents
}
reader := cryptoReader(fileReader, key)
newFilename := headers[0].OriginalFilename
if fileExists(newFilename) {
newFilename = prompt("A file already exists named '%s'. Enter new file name: ", newFilename)
}
_ = os.Truncate(dstPath, 0)
_ = os.Truncate(newFilename, 0)
newFile, err := os.OpenFile(newFilename, os.O_WRONLY|os.O_CREATE, 0644)
newFile, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return err
}
@ -108,34 +156,3 @@ func Bind(dir string) error {
return err
}
// this function gets the header from the horcrux file and ensures that we leave
// the file with its read pointer at the start of the encrypted content
// so that we can later directly read from that point
// yes this is a side effect, no I'm not proud of it.
func getHeaderFromHorcruxFile(file *os.File) (*horcruxHeader, error) {
currentHeader := &horcruxHeader{}
scanner := bufio.NewScanner(file)
bytesBeforeBody := 0
for scanner.Scan() {
line := scanner.Text()
bytesBeforeBody += len(scanner.Bytes()) + 1
if line == "-- HEADER --" {
scanner.Scan()
bytesBeforeBody += len(scanner.Bytes()) + 1
headerLine := scanner.Bytes()
json.Unmarshal(headerLine, currentHeader)
scanner.Scan() // one more to get past the body line
bytesBeforeBody += len(scanner.Bytes()) + 1
break
}
}
if _, err := file.Seek(int64(bytesBeforeBody), io.SeekStart); err != nil {
return nil, err
}
if currentHeader == nil {
return nil, errors.New("could not find header in horcrux file")
}
return currentHeader, nil
}

90
pkg/commands/horcrux.go Normal file
View file

@ -0,0 +1,90 @@
package commands
import (
"bufio"
"encoding/json"
"errors"
"io"
"os"
)
type HorcruxHeader struct {
OriginalFilename string `json:"originalFilename"`
Timestamp int64 `json:"timestamp"`
Index int `json:"index"`
Total int `json:"total"`
Threshold int `json:"threshold"`
KeyFragment []byte `json:"keyFragment"`
}
type Horcrux struct {
path string
header HorcruxHeader
file *os.File
}
// returns a horcrux with its header parsed, and it's file's read pointer
// right after the header.
func NewHorcrux(path string) (*Horcrux, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
header, err := GetHeaderFromHorcruxFile(file)
if err != nil {
return nil, err
}
return &Horcrux{
path: path,
file: file,
header: *header,
}, nil
}
// this function gets the header from the horcrux file and ensures that we leave
// the file with its read pointer at the start of the encrypted content
// so that we can later directly read from that point
// yes this is a side effect, no I'm not proud of it.
func GetHeaderFromHorcruxFile(file *os.File) (*HorcruxHeader, error) {
currentHeader := &HorcruxHeader{}
scanner := bufio.NewScanner(file)
bytesBeforeBody := 0
for scanner.Scan() {
line := scanner.Text()
bytesBeforeBody += len(scanner.Bytes()) + 1
if line == "-- HEADER --" {
scanner.Scan()
bytesBeforeBody += len(scanner.Bytes()) + 1
headerLine := scanner.Bytes()
if err := json.Unmarshal(headerLine, currentHeader); err != nil {
return nil, err
}
scanner.Scan() // one more to get past the body line
bytesBeforeBody += len(scanner.Bytes()) + 1
break
}
}
if _, err := file.Seek(int64(bytesBeforeBody), io.SeekStart); err != nil {
return nil, err
}
if currentHeader == nil {
return nil, errors.New("could not find header in horcrux file")
}
return currentHeader, nil
}
func (h *Horcrux) GetHeader() HorcruxHeader {
return h.header
}
func (h *Horcrux) GetPath() string {
return h.path
}
func (h *Horcrux) GetFile() *os.File {
return h.file
}

View file

@ -1,10 +0,0 @@
package commands
type horcruxHeader struct {
OriginalFilename string `json:"originalFilename"`
Timestamp int64 `json:"timestamp"`
Index int `json:"index"`
Total int `json:"total"`
Threshold int `json:"threshold"`
KeyFragment []byte `json:"keyFragment"`
}

View file

@ -64,7 +64,7 @@ func Split(path string, destination string, total int, threshold int) error {
for i := range horcruxFiles {
index := i + 1
headerBytes, err := json.Marshal(&horcruxHeader{
headerBytes, err := json.Marshal(&HorcruxHeader{
OriginalFilename: originalFilename,
Timestamp: timestamp,
Index: index,
@ -90,7 +90,9 @@ func Split(path string, destination string, total int, threshold int) error {
}
defer horcruxFile.Close()
horcruxFile.WriteString(header(index, total, headerBytes))
if _, err := horcruxFile.WriteString(header(index, total, headerBytes)); err != nil {
return err
}
horcruxFiles[i] = horcruxFile
}
@ -133,7 +135,7 @@ func obtainTotalAndThreshold() (int, int, error) {
threshold := *thresholdPtr
if total == 0 {
totalStr := prompt("How many horcruxes do you want to split this file into? (2-99): ")
totalStr := Prompt("How many horcruxes do you want to split this file into? (2-99): ")
var err error
total, err = strconv.Atoi(totalStr)
if err != nil {
@ -142,7 +144,7 @@ func obtainTotalAndThreshold() (int, int, error) {
}
if threshold == 0 {
thresholdStr := prompt("How many horcruxes should be required to reconstitute the original file? If you require all horcruxes, the resulting files will take up less space, but it will feel less magical (2-99): ")
thresholdStr := Prompt("How many horcruxes should be required to reconstitute the original file? If you require all horcruxes, the resulting files will take up less space, but it will feel less magical (2-99): ")
var err error
threshold, err = strconv.Atoi(thresholdStr)
if err != nil {

View file

@ -30,7 +30,7 @@ func fileExists(filename string) bool {
return !info.IsDir()
}
func prompt(message string, args ...interface{}) string {
func Prompt(message string, args ...interface{}) string {
reader := bufio.NewReader(os.Stdin)
fmt.Printf(message, args...)
input, _ := reader.ReadString('\n')