175 lines
6.3 KiB
Go
175 lines
6.3 KiB
Go
package client
|
|
|
|
import (
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"sort"

	doubleratchet "github.com/status-im/doubleratchet"
)
|
|
|
|
// drLocalPair implements doubleratchet.DHPair using raw byte slices.
|
|
type drLocalPair struct {
|
|
priv doubleratchet.Key
|
|
pub doubleratchet.Key
|
|
}
|
|
|
|
// PrivateKey returns the private half of the pair.
func (p drLocalPair) PrivateKey() doubleratchet.Key {
	return p.priv
}
|
|
// PublicKey returns the public half of the pair.
func (p drLocalPair) PublicKey() doubleratchet.Key {
	return p.pub
}
|
|
|
|
// serializedDRState is the JSON form in which a doubleratchet.State is
// persisted. It contains only plain Go types; encoding/json writes each []byte
// field as a base64 string.
//
// Field order and JSON tags define the stored format and must stay stable.
type serializedDRState struct {
	// DH ratchet keys: State.DHr, and the private/public halves of State.DHs.
	DHrPublic  []byte `json:"dhr_pub"`
	DHsPrivate []byte `json:"dhs_priv"`
	DHsPublic  []byte `json:"dhs_pub"`

	// Chain key of State.RootCh.
	RootChCK []byte `json:"root_ch_ck"`

	// Chain key and counter N of State.SendCh.
	SendChCK []byte `json:"send_ch_ck"`
	SendChN  uint32 `json:"send_ch_n"`

	// Chain key and counter N of State.RecvCh.
	RecvChCK []byte `json:"recv_ch_ck"`
	RecvChN  uint32 `json:"recv_ch_n"`

	// State.PN.
	PN uint32 `json:"pn"`

	// Contents of State.MkSkipped as returned by its All method. The outer
	// index is decoded as hex when the state is loaded; the inner index is
	// presumably the message number of the stored key.
	MkSkipped map[string]map[uint][]byte `json:"mk_skipped"`

	// Limits and counters copied verbatim from the State fields of the same
	// name.
	MaxSkip                  uint `json:"max_skip"`
	MaxKeep                  uint `json:"max_keep"`
	MaxMessageKeysPerSession int  `json:"max_mks_per_session"`
	Step                     uint `json:"step"`
	KeysCount                uint `json:"keys_count"`
}
|
|
|
|
// drSessionStorage implements doubleratchet.SessionStorage, persisting state into peer.DrStateJson.
|
|
type drSessionStorage struct{ peer *Peer }
|
|
|
|
func (s *drSessionStorage) Save(id []byte, state *doubleratchet.State) error {
|
|
all, err := state.MkSkipped.All()
|
|
if err != nil {
|
|
return fmt.Errorf("drSessionStorage.Save: MkSkipped.All: %w", err)
|
|
}
|
|
mkSkipped := make(map[string]map[uint][]byte, len(all))
|
|
for k, msgs := range all {
|
|
inner := make(map[uint][]byte, len(msgs))
|
|
for num, mk := range msgs {
|
|
inner[num] = []byte(mk)
|
|
}
|
|
mkSkipped[k] = inner
|
|
}
|
|
|
|
ss := serializedDRState{
|
|
DHrPublic: []byte(state.DHr),
|
|
DHsPrivate: []byte(state.DHs.PrivateKey()),
|
|
DHsPublic: []byte(state.DHs.PublicKey()),
|
|
RootChCK: []byte(state.RootCh.CK),
|
|
SendChCK: []byte(state.SendCh.CK),
|
|
SendChN: state.SendCh.N,
|
|
RecvChCK: []byte(state.RecvCh.CK),
|
|
RecvChN: state.RecvCh.N,
|
|
PN: state.PN,
|
|
MkSkipped: mkSkipped,
|
|
MaxSkip: state.MaxSkip,
|
|
MaxKeep: state.MaxKeep,
|
|
MaxMessageKeysPerSession: state.MaxMessageKeysPerSession,
|
|
Step: state.Step,
|
|
KeysCount: state.KeysCount,
|
|
}
|
|
|
|
b, err := json.Marshal(ss)
|
|
if err != nil {
|
|
return fmt.Errorf("drSessionStorage.Save: json.Marshal: %w", err)
|
|
}
|
|
s.peer.DrStateJson = string(b)
|
|
return nil
|
|
}
|
|
|
|
func (s *drSessionStorage) Load(id []byte) (*doubleratchet.State, error) {
|
|
if s.peer.DrStateJson == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
var ss serializedDRState
|
|
if err := json.Unmarshal([]byte(s.peer.DrStateJson), &ss); err != nil {
|
|
return nil, fmt.Errorf("drSessionStorage.Load: json.Unmarshal: %w", err)
|
|
}
|
|
|
|
c := doubleratchet.DefaultCrypto{}
|
|
mkStorage := &doubleratchet.KeysStorageInMemory{}
|
|
seq := uint(0)
|
|
for k, msgs := range ss.MkSkipped {
|
|
pubKey, err := hex.DecodeString(k)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("drSessionStorage.Load: decode skipped key hex: %w", err)
|
|
}
|
|
for num, mk := range msgs {
|
|
if err := mkStorage.Put(id, doubleratchet.Key(pubKey), num, doubleratchet.Key(mk), seq); err != nil {
|
|
return nil, fmt.Errorf("drSessionStorage.Load: Put: %w", err)
|
|
}
|
|
seq++
|
|
}
|
|
}
|
|
|
|
state := &doubleratchet.State{
|
|
Crypto: c,
|
|
DHr: doubleratchet.Key(ss.DHrPublic),
|
|
DHs: drLocalPair{priv: doubleratchet.Key(ss.DHsPrivate), pub: doubleratchet.Key(ss.DHsPublic)},
|
|
PN: ss.PN,
|
|
MkSkipped: mkStorage,
|
|
MaxSkip: ss.MaxSkip,
|
|
MaxKeep: ss.MaxKeep,
|
|
MaxMessageKeysPerSession: ss.MaxMessageKeysPerSession,
|
|
Step: ss.Step,
|
|
KeysCount: ss.KeysCount,
|
|
}
|
|
state.RootCh.CK = doubleratchet.Key(ss.RootChCK)
|
|
state.RootCh.Crypto = c
|
|
state.SendCh.CK = doubleratchet.Key(ss.SendChCK)
|
|
state.SendCh.N = ss.SendChN
|
|
state.SendCh.Crypto = c
|
|
state.RecvCh.CK = doubleratchet.Key(ss.RecvChCK)
|
|
state.RecvCh.N = ss.RecvChN
|
|
state.RecvCh.Crypto = c
|
|
|
|
return state, nil
|
|
}
|
|
|
|
// GetDRSession returns an active DR session for the peer, creating one if needed.
|
|
func (p *Peer) GetDRSession() (doubleratchet.Session, error) {
|
|
store := &drSessionStorage{peer: p}
|
|
|
|
// If we already have a saved state, load it
|
|
if p.DrStateJson != "" {
|
|
return doubleratchet.Load([]byte(p.Uid), store)
|
|
}
|
|
|
|
// Initiator: has own DH keypair + root key, no state yet
|
|
if p.DrInitiator && p.DrKpPrivate != "" {
|
|
privBytes, err := base64.StdEncoding.DecodeString(p.DrKpPrivate)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GetDRSession: decode DrKpPrivate: %w", err)
|
|
}
|
|
pubBytes, err := base64.StdEncoding.DecodeString(p.DrKpPublic)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GetDRSession: decode DrKpPublic: %w", err)
|
|
}
|
|
rootKeyBytes, err := base64.StdEncoding.DecodeString(p.DrRootKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GetDRSession: decode DrRootKey: %w", err)
|
|
}
|
|
kp := drLocalPair{priv: doubleratchet.Key(privBytes), pub: doubleratchet.Key(pubBytes)}
|
|
return doubleratchet.New([]byte(p.Uid), doubleratchet.Key(rootKeyBytes), kp, store)
|
|
}
|
|
|
|
// Responder: has remote DH public key + root key
|
|
if !p.DrInitiator && p.ContactDrPublicKey != "" {
|
|
remotePubBytes, err := base64.StdEncoding.DecodeString(p.ContactDrPublicKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GetDRSession: decode ContactDrPublicKey: %w", err)
|
|
}
|
|
rootKeyBytes, err := base64.StdEncoding.DecodeString(p.DrRootKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GetDRSession: decode DrRootKey: %w", err)
|
|
}
|
|
return doubleratchet.NewWithRemoteKey([]byte(p.Uid), doubleratchet.Key(rootKeyBytes), doubleratchet.Key(remotePubBytes), store)
|
|
}
|
|
|
|
return nil, fmt.Errorf("GetDRSession: peer %s has no DR keys configured", p.Uid)
|
|
}
|