diff --git a/scribble.go b/scribble.go index eb9712d..50f8583 100644 --- a/scribble.go +++ b/scribble.go @@ -99,6 +99,7 @@ func (d *Driver) write(trans Transaction) error { mutex := d.getOrCreateMutex(trans.Collection) mutex.Lock() + defer mutex.Unlock() // dir := d.dir + "/" + trans.Collection @@ -114,14 +115,16 @@ func (d *Driver) write(trans Transaction) error { return err } + finalPath := dir + "/" + trans.ResourceID + ".json" + tmpPath := finalPath + "~" + // write marshaled data to a file, named by the resourceID - if err := ioutil.WriteFile(dir+"/"+trans.ResourceID, b, 0666); err != nil { + if err := ioutil.WriteFile(tmpPath, b, 0666); err != nil { return err } - mutex.Unlock() - - return nil + // move final file into place + return os.Rename(tmpPath, finalPath) } // read does the opposite operation as write. Reading a record from the database @@ -138,11 +141,7 @@ func (d *Driver) read(trans Transaction) error { } // unmarshal data into the transaction.Container - if err := json.Unmarshal(b, trans.Container); err != nil { - return err - } - - return nil + return json.Unmarshal(b, trans.Container) } // readAll does the same operation as read, reading all the records in the specified @@ -174,11 +173,7 @@ func (d *Driver) readAll(trans Transaction) error { } // unmarhsal the read files as a comma delimeted byte array - if err := json.Unmarshal([]byte("["+strings.Join(f, ",")+"]"), trans.Container); err != nil { - return err - } - - return nil + return json.Unmarshal([]byte("["+strings.Join(f, ",")+"]"), trans.Container) } // delete locks that database and then proceeds to remove the record (specified by @@ -187,18 +182,12 @@ func (d *Driver) delete(trans Transaction) error { mutex := d.getOrCreateMutex(trans.Collection) mutex.Lock() + defer mutex.Unlock() dir := d.dir + "/" + trans.Collection // remove record from database - err := os.Remove(dir + "/" + trans.ResourceID) - if err != nil { - return err - } - - mutex.Unlock() - - return nil + return os.Remove(dir + "/" + trans.ResourceID) } // helpers @@ -211,8 +200,8 @@ func (d *Driver) getOrCreateMutex(collection string) sync.Mutex { // if the mutex doesn't exist make it if !ok { - d.mutexes[collection] = sync.Mutex{} - return d.mutexes[collection] + c = sync.Mutex{} + d.mutexes[collection] = c } return c @@ -220,17 +209,17 @@ func (d *Driver) getOrCreateMutex(collection string) sync.Mutex { // mkDir is a simple wrapper that attempts to make a directory at a specified // location -func mkDir(d string) error { +func mkDir(d string) (err error) { // dir, _ := os.Stat(d) - if dir == nil { - err := os.MkdirAll(d, 0755) - if err != nil { - return err - } + switch { + case dir == nil: + err = os.MkdirAll(d, 0755) + case !dir.IsDir(): + err = os.ErrInvalid } - return nil + return }