package client

import (
	"database/sql"
	"errors"
	"fmt"
	"math"
	"os"
	"path/filepath"
	"sync"

	"forge.redroom.link/yves/meowlib"
	"github.com/google/uuid"
	_ "github.com/mattn/go-sqlite3"
	"google.golang.org/protobuf/proto"
)

// dbFileMu holds one *sync.RWMutex per SQLite file path. Entries are never
// deleted (bounded by peer count, which is small). RLock for reads, Lock for
// writes.
var dbFileMu sync.Map

// getDbFileMutex returns the RWMutex guarding the SQLite file at path,
// creating and registering it on first use.
func getDbFileMutex(path string) *sync.RWMutex {
	v, _ := dbFileMu.LoadOrStore(path, &sync.RWMutex{})
	return v.(*sync.RWMutex)
}

// withDbWrite opens the SQLite file at path while holding that file's
// exclusive lock, runs fn, and closes the handle before releasing the lock.
func withDbWrite(path string, fn func(*sql.DB) error) error {
	mu := getDbFileMutex(path)
	mu.Lock()
	defer mu.Unlock()
	db, err := sql.Open("sqlite3", path)
	if err != nil {
		return err
	}
	defer db.Close()
	return fn(db)
}

// withDbRead opens the SQLite file at path while holding that file's shared
// (read) lock, runs fn, and closes the handle before releasing the lock.
func withDbRead(path string, fn func(*sql.DB) error) error {
	mu := getDbFileMutex(path)
	mu.RLock()
	defer mu.RUnlock()
	db, err := sql.Open("sqlite3", path)
	if err != nil {
		return err
	}
	defer db.Close()
	return fn(db)
}

// dbPath returns <StoragePath>/<identity.Uuid>/<dbid><DbSuffix>, the location
// of the message database file dbid.
func dbPath(cfg *Config, identity *Identity, dbid string) string {
	return filepath.Join(cfg.StoragePath, identity.Uuid, dbid+cfg.DbSuffix)
}

// unsealDbMessage decrypts an encrypted "m" column value with password and
// unmarshals it into a freshly allocated DbMessage.
func unsealDbMessage(m []byte, password string) (*meowlib.DbMessage, error) {
	decdata, err := meowlib.SymDecrypt(password, m)
	if err != nil {
		return nil, err
	}
	dbm := &meowlib.DbMessage{}
	if err := proto.Unmarshal(decdata, dbm); err != nil {
		return nil, err
	}
	return dbm, nil
}

// decodeMessageRows turns every (id, m) row of rows into an
// InternalUserMessage tagged with its row id and database file dbfile.
// The first scan/decrypt/unmarshal error aborts decoding.
func decodeMessageRows(rows *sql.Rows, dbfile string, password string) ([]*InternalUserMessage, error) {
	var out []*InternalUserMessage
	for rows.Next() {
		var id int64
		var m []byte
		if err := rows.Scan(&id, &m); err != nil {
			return nil, err
		}
		dbm, err := unsealDbMessage(m, password)
		if err != nil {
			return nil, err
		}
		ium := DbMessageToInternalUserMessage(id, dbfile, dbm)
		ium.Dbid = id
		ium.Dbfile = dbfile
		out = append(out, ium)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}

// storeMessage encrypts usermessage with password and appends it to the
// peer's newest message database.
//
// If the peer has no database yet, a new one is allocated (new UUID recorded
// in peer.DbIds, peer persisted, identity folder created) and its message
// table is created. Each attached file is encrypted into
// <StoragePath>/<identity>/securefiles/<uuid>, and its Data is replaced in
// place by that path. On success, peer.LastMessage is updated and the peer
// is persisted again.
//
// The filenames parameter is currently unused.
func storeMessage(peer *Peer, usermessage *meowlib.UserMessage, filenames []string, password string) error {
	cfg := GetConfig()
	identity := cfg.GetIdentity()

	isNew := len(peer.DbIds) == 0
	var dbid string
	if isNew {
		dbid = uuid.NewString()
		peer.DbIds = []string{dbid}
		identity.Peers.StorePeer(peer)
		identity.CreateFolder()
	} else {
		dbid = peer.DbIds[len(peer.DbIds)-1]
	}

	// Detach file attachments — no DB lock needed for file I/O.
	hiddenFilenames := []string{}
	if len(usermessage.Files) > 0 {
		secureDir := filepath.Join(cfg.StoragePath, identity.Uuid, "securefiles")
		// MkdirAll is a no-op when the directory already exists.
		if err := os.MkdirAll(secureDir, 0755); err != nil {
			return err
		}
		for _, f := range usermessage.Files {
			encData, err := meowlib.SymEncrypt(password, f.Data)
			if err != nil {
				return err
			}
			hidden := filepath.Join(secureDir, uuid.NewString())
			// A failed write must not be ignored: the message would otherwise
			// reference an attachment file that does not exist.
			if err := os.WriteFile(hidden, encData, 0600); err != nil {
				return fmt.Errorf("writing encrypted attachment %q: %w", hidden, err)
			}
			hiddenFilenames = append(hiddenFilenames, hidden)
			f.Data = []byte(hidden)
		}
	}

	outbound := usermessage.From != peer.ContactPublicKey
	dbm := UserMessageToDbMessage(outbound, usermessage, hiddenFilenames)
	out, err := proto.Marshal(dbm)
	if err != nil {
		return err
	}
	encData, err := meowlib.SymEncrypt(password, out)
	if err != nil {
		return err
	}

	var id int64
	err = withDbWrite(dbPath(cfg, identity, dbid), func(db *sql.DB) error {
		// SQLite creates the file on first Open; create the table if new DB.
		if isNew {
			if err := createMessageTable(db); err != nil {
				return err
			}
		}
		res, err := db.Exec(`INSERT INTO message(m) VALUES (?)`, encData)
		if err != nil {
			return err
		}
		id, err = res.LastInsertId()
		return err
	})
	if err != nil {
		return err
	}

	peer.LastMessage = DbMessageToInternalUserMessage(id, dbid, dbm)
	identity.Peers.StorePeer(peer)
	return nil
}

// loadNewMessages returns the messages in the peer's newest database whose
// row id is greater than lastDbId, newest first. A lastDbId of 0 returns
// every message in that file, because AUTOINCREMENT ids start at 1.
// It returns (nil, nil) when the peer has no database yet.
func loadNewMessages(peer *Peer, lastDbId int, password string) ([]*InternalUserMessage, error) {
	cfg := GetConfig()
	identity := cfg.GetIdentity()
	if len(peer.DbIds) == 0 {
		return nil, nil
	}
	dbfile := peer.DbIds[len(peer.DbIds)-1]

	var messages []*InternalUserMessage
	err := withDbRead(dbPath(cfg, identity, dbfile), func(db *sql.DB) error {
		rows, err := db.Query("SELECT id, m FROM message WHERE id > ? ORDER BY id DESC", int64(lastDbId))
		if err != nil {
			return err
		}
		defer rows.Close()
		messages, err = decodeMessageRows(rows, dbfile, password)
		return err
	})
	// TODO DB overlap: only the newest database file is scanned.
	return messages, err
}

// loadMessagesHistory returns up to wantMore messages older than lastDbId,
// newest first. A lastDbId of 0 means "no upper bound".
//
// inAppMsgCount is the number of messages the caller already holds. The
// function walks back through peer.DbIds until the cumulative row count
// covers it, then queries that file. It returns (nil, nil) when the history
// is exhausted.
func loadMessagesHistory(peer *Peer, inAppMsgCount int, lastDbId int, wantMore int, password string) ([]InternalUserMessage, error) {
	var messages []InternalUserMessage
	cfg := GetConfig()
	if len(peer.DbIds) == 0 {
		return messages, nil
	}

	fileidx := len(peer.DbIds) - 1
	countStack, err := getMessageCount(peer.DbIds[fileidx])
	if err != nil {
		return nil, err
	}
	for inAppMsgCount > countStack {
		fileidx--
		if fileidx < 0 {
			return nil, nil
		}
		newCount, err := getMessageCount(peer.DbIds[fileidx])
		if err != nil {
			return nil, err
		}
		countStack += newCount
	}

	// Keep the bound in an int64: assigning math.MaxInt64 to an int does
	// not compile on 32-bit platforms.
	upperID := int64(lastDbId)
	if lastDbId == 0 {
		upperID = math.MaxInt64
	}

	dbfile := peer.DbIds[fileidx]
	err = withDbRead(dbPath(cfg, cfg.GetIdentity(), dbfile), func(db *sql.DB) error {
		rows, err := db.Query("SELECT id, m FROM message WHERE id < ? ORDER BY id DESC LIMIT ?", upperID, wantMore)
		if err != nil {
			return err
		}
		defer rows.Close()
		page, err := decodeMessageRows(rows, dbfile, password)
		if err != nil {
			return err
		}
		for _, ium := range page {
			messages = append(messages, *ium)
		}
		return nil
	})
	// TODO DB overlap: lastDbId is applied to whichever file was selected above.
	return messages, err
}

// GetDbMessage loads and decrypts the message stored at row dbId of database
// file dbFile. It returns an error if the row does not exist.
func GetDbMessage(dbFile string, dbId int64, password string) (*meowlib.DbMessage, error) {
	cfg := GetConfig()
	path := dbPath(cfg, cfg.GetIdentity(), dbFile)

	var m []byte
	err := withDbRead(path, func(db *sql.DB) error {
		return db.QueryRow("SELECT m FROM message WHERE id=?", dbId).Scan(&m)
	})
	if errors.Is(err, sql.ErrNoRows) {
		return nil, fmt.Errorf("message row %d not found in %s", dbId, dbFile)
	}
	if err != nil {
		return nil, err
	}
	return unsealDbMessage(m, password)
}

// UpdateDbMessage re-encrypts dbm with password and overwrites row dbId of
// database file dbFile. Updating a missing row is not reported as an error.
func UpdateDbMessage(dbm *meowlib.DbMessage, dbFile string, dbId int64, password string) error {
	cfg := GetConfig()
	path := dbPath(cfg, cfg.GetIdentity(), dbFile)

	out, err := proto.Marshal(dbm)
	if err != nil {
		return err
	}
	encData, err := meowlib.SymEncrypt(password, out)
	if err != nil {
		return err
	}
	return withDbWrite(path, func(db *sql.DB) error {
		_, err := db.Exec(`UPDATE message SET m=? WHERE id=?`, encData, dbId)
		return err
	})
}

// GetMessagePreview returns the decrypted content of the first attachment of
// the stored message, or (nil, nil) when the message has no attachment.
func GetMessagePreview(dbFile string, dbId int64, password string) ([]byte, error) {
	dbm, err := GetDbMessage(dbFile, dbId, password)
	if err != nil {
		return nil, err
	}
	if len(dbm.FilePaths) == 0 {
		return nil, nil
	}
	return FilePreview(dbm.FilePaths[0], password)
}

// FilePreview reads the encrypted file filename and returns its decrypted
// content.
func FilePreview(filename string, password string) ([]byte, error) {
	encData, err := os.ReadFile(filename)
	if err != nil {
		return nil, err
	}
	return meowlib.SymDecrypt(password, encData)
}

// InternalUserMessagePreview returns the decrypted content of msg's first
// attachment, or (nil, nil) when it has none.
func InternalUserMessagePreview(msg *InternalUserMessage, password string) ([]byte, error) {
	if len(msg.FilePaths) == 0 {
		return nil, nil
	}
	return FilePreview(msg.FilePaths[0], password)
}

// getMessageCount returns the number of rows in the message table of
// database file dbid.
func getMessageCount(dbid string) (int, error) {
	cfg := GetConfig()
	var count int
	err := withDbRead(dbPath(cfg, cfg.GetIdentity(), dbid), func(db *sql.DB) error {
		return db.QueryRow("SELECT COUNT(*) FROM message").Scan(&count)
	})
	return count, err
}

// SetMessageServerDelivery updates the server delivery UUID and timestamp for a stored message.
func SetMessageServerDelivery(dbFile string, dbId int64, serverUid string, receiveTime uint64, password string) error {
	dbm, err := GetDbMessage(dbFile, dbId, password)
	if err != nil {
		return err
	}
	dbm.ServerDeliveryUuid = serverUid
	dbm.ServerDeliveryTimestamp = receiveTime
	return UpdateDbMessage(dbm, dbFile, dbId, password)
}

// FindMessageByUuid scans all DB files for a peer (newest file first, newest
// row first) and returns the dbFile, row ID, and DbMessage for the first
// message whose Status.Uuid equals messageUuid. Rows that cannot be decoded
// and files that cannot be read are skipped.
func FindMessageByUuid(peer *Peer, messageUuid string, password string) (string, int64, *meowlib.DbMessage, error) { cfg := GetConfig() identity := cfg.GetIdentity() for i := len(peer.DbIds) - 1; i >= 0; i-- { dbid := peer.DbIds[i] path := filepath.Join(cfg.StoragePath, identity.Uuid, dbid+cfg.DbSuffix) var foundFile string var foundId int64 var foundMsg meowlib.DbMessage err := withDbRead(path, func(db *sql.DB) error { rows, err := db.Query("SELECT id, m FROM message ORDER BY id DESC") if err != nil { return err } defer rows.Close() for rows.Next() { var id int64 var m []byte if err := rows.Scan(&id, &m); err != nil { continue } decdata, err := meowlib.SymDecrypt(password, m) if err != nil { continue } var dbm meowlib.DbMessage if err := proto.Unmarshal(decdata, &dbm); err != nil { continue } if dbm.Status != nil && dbm.Status.Uuid == messageUuid { foundFile = dbid foundId = id foundMsg = dbm return nil } } return nil }) if err == nil && foundFile != "" { return foundFile, foundId, &foundMsg, nil } } return "", 0, nil, fmt.Errorf("message with UUID %s not found", messageUuid) } // UpdateMessageAck finds a stored outbound message by UUID and stamps it with // the received and/or processed timestamps from an inbound ACK message. 
func UpdateMessageAck(peer *Peer, messageUuid string, receivedAt uint64, processedAt uint64, password string) error { dbFile, dbId, dbm, err := FindMessageByUuid(peer, messageUuid, password) if err != nil { return err } if dbm.Status == nil { dbm.Status = &meowlib.ConversationStatus{} } if receivedAt != 0 { dbm.Status.Received = receivedAt } if processedAt != 0 { dbm.Status.Processed = processedAt } return UpdateDbMessage(dbm, dbFile, dbId, password) } func createMessageTable(db *sql.DB) error { stmt, err := db.Prepare(`CREATE TABLE message ( "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "m" BLOB)`) if err != nil { return err } stmt.Exec() return nil } func createServerTable(db *sql.DB) error { stmt, err := db.Prepare(`CREATE TABLE servers ( "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "country" varchar(2), "public" bool, "uptime" int, "bandwith" float, "load" float, "url" varchar(2000) "name" varchar(255); "description" varchar(5000) "publickey" varchar(10000) )`) if err != nil { return err } stmt.Exec() return nil }