diff --git a/scrib_test.go b/scrib_test.go new file mode 100644 index 0000000..6966d24 --- /dev/null +++ b/scrib_test.go @@ -0,0 +1,78 @@ +package scribble + +import ( + "fmt" + "sync" + "testing" +) + +type logger struct { + t *testing.T +} + +func (l logger) Fatal(f string, a ...interface{}) { l.t.Fatalf(f, a...) } +func (l logger) Error(f string, a ...interface{}) { l.t.Fatalf(f, a...) } +func (l logger) Warn(f string, a ...interface{}) { l.t.Fatalf(f, a...) } +func (l logger) Info(f string, a ...interface{}) {} +func (l logger) Debug(f string, a ...interface{}) {} +func (l logger) Trace(f string, a ...interface{}) {} + +func TestBasic(t *testing.T) { + var d *Driver + var err error + + if d, err = New("./test-dir", logger{t}); err != nil { + t.Fatal(err) + } + + if err = d.Write("/fish", "big", "small"); err != nil { + t.Fatal(err) + } + + var ans string + + if err = d.Read("/fish/big", &ans); err != nil { + t.Fatal(err) + } + + if ans != "small" { + t.Fatal("Expected 'small' but read back ", ans) + } + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + if err1 := d.Write("/fish", fmt.Sprintf("num%v", i), fmt.Sprintf("%v", i)); err1 != nil { + t.Fatal(err1) + return + } + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for i := 10; i < 20; i++ { + if err1 := d.Write("/fish", fmt.Sprintf("num%v", i), fmt.Sprintf("%v", i)); err1 != nil { + t.Fatal(err1) + return + } + } + }() + + wg.Wait() + + var fishes []string + + if err := d.Read("/fish", &fishes); err != nil { + t.Fatal(err) + } + + if len(fishes) != 21 { + t.Fatalf("Expected 21 entries but found %v", len(fishes)) + } + +} diff --git a/scribble.go b/scribble.go index b6ea684..df24bfd 100644 --- a/scribble.go +++ b/scribble.go @@ -9,10 +9,12 @@ package scribble import ( + "bytes" "encoding/json" "fmt" "io/ioutil" "os" + "path/filepath" "strings" "sync" @@ -26,6 +28,7 @@ type ( // a Driver is what is used to interact with the scribble database. It runs // transactions, and provides log output Driver struct { + maplock sync.RWMutex mutexes map[string]sync.Mutex dir string // the directory where scribble will create the database log hatchet.Logger // the logger scribble will log to @@ -35,6 +38,8 @@ type ( // New creates a new scribble database at the desired directory location, and // returns a *Driver to then use for interacting with the database func New(dir string, logger hatchet.Logger) (*Driver, error) { + dir = filepath.Clean(dir) + fmt.Printf("Creating database directory at '%v'...\n", dir) // @@ -67,7 +72,7 @@ func (d *Driver) Write(collection, resource string, v interface{}) error { defer mutex.Unlock() // - dir := d.dir + collection + dir := filepath.Join(d.dir, collection) // b, err := json.MarshalIndent(v, "", "\t") @@ -80,7 +85,7 @@ func (d *Driver) Write(collection, resource string, v interface{}) error { return err } - finalPath := dir + "/" + resource + ".json" + finalPath := filepath.Join(dir, resource+".json") tmpPath := finalPath + "~" // write marshaled data to the temp file @@ -88,6 +93,17 @@ func (d *Driver) Write(collection, resource string, v interface{}) error { return err } + if _, err := os.Stat(finalPath); err == nil { + if _, err = os.Stat(finalPath + ".bak"); err == nil { + if err = os.Remove(finalPath + ".bak"); err != nil { + return err + } + } + if err = os.Rename(finalPath, finalPath+".bak"); err != nil { + return err + } + } + // move final file into place return os.Rename(tmpPath, finalPath) } @@ -95,57 +111,73 @@ func (d *Driver) Write(collection, resource string, v interface{}) error { // Read a record from the database func (d *Driver) Read(path string, v interface{}) error { - dir := d.dir + path + var err error + var fi os.FileInfo + + dir := filepath.Join(d.dir, path) // - fi, err := os.Stat(path) - if err != nil { - return err - } + fi, err = os.Stat(dir) - switch { + if err == nil { + if !fi.Mode().IsDir() { + return fmt.Errorf("Expected path %v to be a folder", path) + } - // if the path is a directory, attempt to read all entries into v - case fi.Mode().IsDir(): + var files []os.FileInfo // read all the files in the transaction.Collection - files, err := ioutil.ReadDir(dir) + files, err = ioutil.ReadDir(dir) if err != nil { // an error here just means the collection is either empty or doesn't exist } + buf := bytes.Buffer{} + + buf.WriteString("[") + // the files read from the database - var f []string + if len(files) > 0 { - // iterate over each of the files, attempting to read the file. If successful - // append the files to the collection of read files - for _, file := range files { - b, err := ioutil.ReadFile(dir + "/" + file.Name()) - if err != nil { - return err + // iterate over each of the files, attempting to read the file. If successful + // append the files to the collection of read files + for _, file := range files { + if !strings.HasSuffix(file.Name(), ".json") { + continue + } + + b, err := ioutil.ReadFile(filepath.Join(dir, file.Name())) + if err != nil { + return err + } + + // append read file + buf.Write(b) + buf.WriteString(",") } - - // append read file - f = append(f, string(b)) + buf.Truncate(buf.Len() - len(",")) } + buf.WriteString("]") + // unmarhsal the read files as a comma delimeted byte array - return json.Unmarshal([]byte("["+strings.Join(f, ",")+"]"), v) - - // if the path is a file, attempt to read the single file - case !fi.Mode().IsDir(): - - // read record from database - b, err := ioutil.ReadFile(dir + ".json") - if err != nil { - return err - } - - // unmarshal data into the transaction.Container - return json.Unmarshal(b, &v) + return json.Unmarshal(buf.Bytes(), v) } - return nil + fi, err = os.Stat(dir + ".json") + if err != nil { + return err + } + + var b []byte + b, err = ioutil.ReadFile(dir + ".json") + if err != nil { + return err + } + + // unmarshal data into the transaction.Container + return json.Unmarshal(b, &v) + } // Delete locks that database and then attempts to remove the collection/resource @@ -165,11 +197,11 @@ func (d *Driver) Delete(path string) error { switch { // remove the collection from database case fi.Mode().IsDir(): - return os.Remove(d.dir + path) + return os.Remove(filepath.Join(d.dir, path)) // remove the record from database default: - return os.Remove(d.dir + path + ".json") + return os.Remove(filepath.Join(d.dir, path, ".json")) } } @@ -177,12 +209,18 @@ func (d *Driver) Delete(path string) error { // is being modfied to avoid unsafe operations func (d *Driver) getOrCreateMutex(collection string) sync.Mutex { + d.maplock.RLock() + c, ok := d.mutexes[collection] + d.maplock.RUnlock() // if the mutex doesn't exist make it if !ok { + + d.maplock.Lock() c = sync.Mutex{} d.mutexes[collection] = c + d.maplock.Unlock() } return c