package client import ( "encoding/hex" "encoding/json" "encoding/base64" "fmt" doubleratchet "github.com/status-im/doubleratchet" ) // drLocalPair implements doubleratchet.DHPair using raw byte slices. type drLocalPair struct { priv doubleratchet.Key pub doubleratchet.Key } func (p drLocalPair) PrivateKey() doubleratchet.Key { return p.priv } func (p drLocalPair) PublicKey() doubleratchet.Key { return p.pub } // serializedDRState is an intermediate JSON-friendly representation of doubleratchet.State. type serializedDRState struct { DHrPublic []byte `json:"dhr_pub"` DHsPrivate []byte `json:"dhs_priv"` DHsPublic []byte `json:"dhs_pub"` RootChCK []byte `json:"root_ch_ck"` SendChCK []byte `json:"send_ch_ck"` SendChN uint32 `json:"send_ch_n"` RecvChCK []byte `json:"recv_ch_ck"` RecvChN uint32 `json:"recv_ch_n"` PN uint32 `json:"pn"` MkSkipped map[string]map[uint][]byte `json:"mk_skipped"` MaxSkip uint `json:"max_skip"` MaxKeep uint `json:"max_keep"` MaxMessageKeysPerSession int `json:"max_mks_per_session"` Step uint `json:"step"` KeysCount uint `json:"keys_count"` } // drSessionStorage implements doubleratchet.SessionStorage, persisting state into peer.DrStateJson. type drSessionStorage struct{ peer *Peer } func (s *drSessionStorage) Save(id []byte, state *doubleratchet.State) error { all, err := state.MkSkipped.All() if err != nil { return fmt.Errorf("drSessionStorage.Save: MkSkipped.All: %w", err) } mkSkipped := make(map[string]map[uint][]byte, len(all)) for k, msgs := range all { inner := make(map[uint][]byte, len(msgs)) for num, mk := range msgs { inner[num] = []byte(mk) } mkSkipped[k] = inner } ss := serializedDRState{ DHrPublic: []byte(state.DHr), DHsPrivate: []byte(state.DHs.PrivateKey()), DHsPublic: []byte(state.DHs.PublicKey()), RootChCK: []byte(state.RootCh.CK), SendChCK: []byte(state.SendCh.CK), SendChN: state.SendCh.N, RecvChCK: []byte(state.RecvCh.CK), RecvChN: state.RecvCh.N, PN: state.PN, MkSkipped: mkSkipped, MaxSkip: state.MaxSkip, MaxKeep: state.MaxKeep, MaxMessageKeysPerSession: state.MaxMessageKeysPerSession, Step: state.Step, KeysCount: state.KeysCount, } b, err := json.Marshal(ss) if err != nil { return fmt.Errorf("drSessionStorage.Save: json.Marshal: %w", err) } s.peer.DrStateJson = string(b) return nil } func (s *drSessionStorage) Load(id []byte) (*doubleratchet.State, error) { if s.peer.DrStateJson == "" { return nil, nil } var ss serializedDRState if err := json.Unmarshal([]byte(s.peer.DrStateJson), &ss); err != nil { return nil, fmt.Errorf("drSessionStorage.Load: json.Unmarshal: %w", err) } c := doubleratchet.DefaultCrypto{} mkStorage := &doubleratchet.KeysStorageInMemory{} seq := uint(0) for k, msgs := range ss.MkSkipped { pubKey, err := hex.DecodeString(k) if err != nil { return nil, fmt.Errorf("drSessionStorage.Load: decode skipped key hex: %w", err) } for num, mk := range msgs { if err := mkStorage.Put(id, doubleratchet.Key(pubKey), num, doubleratchet.Key(mk), seq); err != nil { return nil, fmt.Errorf("drSessionStorage.Load: Put: %w", err) } seq++ } } state := &doubleratchet.State{ Crypto: c, DHr: doubleratchet.Key(ss.DHrPublic), DHs: drLocalPair{priv: doubleratchet.Key(ss.DHsPrivate), pub: doubleratchet.Key(ss.DHsPublic)}, PN: ss.PN, MkSkipped: mkStorage, MaxSkip: ss.MaxSkip, MaxKeep: ss.MaxKeep, MaxMessageKeysPerSession: ss.MaxMessageKeysPerSession, Step: ss.Step, KeysCount: ss.KeysCount, } state.RootCh.CK = doubleratchet.Key(ss.RootChCK) state.RootCh.Crypto = c state.SendCh.CK = doubleratchet.Key(ss.SendChCK) state.SendCh.N = ss.SendChN state.SendCh.Crypto = c state.RecvCh.CK = doubleratchet.Key(ss.RecvChCK) state.RecvCh.N = ss.RecvChN state.RecvCh.Crypto = c return state, nil } // GetDRSession returns an active DR session for the peer, creating one if needed. func (p *Peer) GetDRSession() (doubleratchet.Session, error) { store := &drSessionStorage{peer: p} // If we already have a saved state, load it if p.DrStateJson != "" { return doubleratchet.Load([]byte(p.Uid), store) } // Initiator: has own DH keypair + root key, no state yet if p.DrInitiator && p.DrKpPrivate != "" { privBytes, err := base64.StdEncoding.DecodeString(p.DrKpPrivate) if err != nil { return nil, fmt.Errorf("GetDRSession: decode DrKpPrivate: %w", err) } pubBytes, err := base64.StdEncoding.DecodeString(p.DrKpPublic) if err != nil { return nil, fmt.Errorf("GetDRSession: decode DrKpPublic: %w", err) } rootKeyBytes, err := base64.StdEncoding.DecodeString(p.DrRootKey) if err != nil { return nil, fmt.Errorf("GetDRSession: decode DrRootKey: %w", err) } kp := drLocalPair{priv: doubleratchet.Key(privBytes), pub: doubleratchet.Key(pubBytes)} return doubleratchet.New([]byte(p.Uid), doubleratchet.Key(rootKeyBytes), kp, store) } // Responder: has remote DH public key + root key if !p.DrInitiator && p.ContactDrPublicKey != "" { remotePubBytes, err := base64.StdEncoding.DecodeString(p.ContactDrPublicKey) if err != nil { return nil, fmt.Errorf("GetDRSession: decode ContactDrPublicKey: %w", err) } rootKeyBytes, err := base64.StdEncoding.DecodeString(p.DrRootKey) if err != nil { return nil, fmt.Errorf("GetDRSession: decode DrRootKey: %w", err) } return doubleratchet.NewWithRemoteKey([]byte(p.Uid), doubleratchet.Key(rootKeyBytes), doubleratchet.Key(remotePubBytes), store) } return nil, fmt.Errorf("GetDRSession: peer %s has no DR keys configured", p.Uid) }