Table of Contents

Lab 08 - WhatsApp End-to-End Encryption

In this lab you will implement a simplified version of the Signal Protocol, which is the basis for WhatsApp's end-to-end encryption.

The first versions of the WhatsApp protocol were described here. A more recent document is available here. WhatsApp's security is based on the Signal protocol, which was first used by TextSecure. The Signal protocol is described in detail in this paper.

For the elliptic-curve Diffie-Hellman operations, you can use this library (donna25519).

For installation, follow these steps:

Task 1

Create a common master_secret for two clients that communicate through a server (TODO 1.1 and TODO 1.2). Print it on both clients and make sure both clients have the same secret.

How to run

Open three different terminals.

First terminal:


Second terminal:


Third terminal:

MSG <id_other_client> Hello!

Second terminal:

from wa_client import Client
from wa_server import Server
if __name__ == '__main__':
    # Interactive loop: "MSG <user_id> <message>" sends a message,
    # "RECV" blocks until the next incoming message and prints it.
    # NOTE(review): `c` (the Client instance) is not created in this snippet;
    # presumably a line like `c = Client(<server_ip>, <server_port>)` was lost
    # from this excerpt — confirm against the original script.
    while True:
        cmd = input('MSG <user_id> <message>\n')
        cmd = cmd.split(" ")
        if cmd[0] == "MSG":
            # Malformed input (missing or non-numeric id) must not kill the client
            try:
                user_id = int(cmd[1])
            except (IndexError, ValueError):
                print("Usage: MSG <user_id> <message>")
                continue
            msg = " ".join(cmd[2:])
            c.send_message(user_id, msg)
        elif cmd[0] == "RECV":
            msg = c.recv_message()
            print(msg)
from wa_client import Client
from wa_server import Server
if __name__ == '__main__':
    # NOTE(review): SERVER_PORT is not defined in this snippet — presumably a
    # constant lost from this excerpt; confirm it matches the client's port.
    s = Server(SERVER_PORT)
    # Without start() the server object is created but never accepts clients.
    s.start()
from donna25519 import *
import os
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import struct
import socket
import hashlib
import hmac
from math import ceil
O_NUM = 10  # how many One-Time Pre Keys each Client generates when created
hash_len = hashlib.sha256().digest_size  # SHA-256 output size in bytes (32)
def hmac_sha256(key, data):
    """Return HMAC-SHA256(key, data) as 32 raw bytes.

    key: HMAC key (bytes)
    data: message to authenticate (bytes)
    """
    return hmac.new(key, data, hashlib.sha256).digest()
def hkdf(length, ikm, salt = b"", info = b""):
    """HKDF with HMAC-SHA256 (RFC 5869): extract-then-expand.

    length: output length in bytes, 0 <= length <= 255 * hash_len
    ikm: input key material (bytes)
    salt: optional salt for the extract step (bytes)
    info: optional context string for the expand step (bytes); the default
        b"" yields exactly the same output as before this parameter existed
    Returns `length` bytes of output keying material.
    Raises ValueError if length is out of range (the block counter is a
    single byte, so at most 255 blocks can be produced).
    """
    max_len = 255 * hash_len
    if not 0 <= length <= max_len:
        raise ValueError("hkdf length must be between 0 and %d bytes" % max_len)
    # Extract: PRK = HMAC(salt, IKM)
    prk = hmac_sha256(salt, ikm)
    t = b""
    okm = b""
    # Expand: T(i) = HMAC(PRK, T(i-1) | info | i), i = 1..ceil(L / HashLen)
    for i in range(1, ceil(length / hash_len) + 1):
        t = hmac_sha256(prk, t + info + bytes([i]))
        okm += t
    return okm[:length]
class ClientSession:
    """Ratchet state shared with one peer.

    root_key: current RootKey
    chain_key_s: sending ChainKey (None until it is first derived)
    chain_key_r: receiving ChainKey (None until it is first derived)
    eph_own: own current ephemeral key pair (PrivateKey)
    eph_peer: peer's current ephemeral public key (None until known)
    """
    def __init__(self, root_key, chain_key_s, chain_key_r, eph_own, eph_peer):
        self.root_key = root_key
        self.chain_key_s = chain_key_s
        self.chain_key_r = chain_key_r
        self.eph_own = eph_own
        self.eph_peer = eph_peer

    def update_keys(self, eph_peer):
        """Performs vertical ratcheting

        Updates root & chain keys to match peer's (for decrypting its messages)
        Then updates root & chain keys again with a fresh ephemeral key (for
        encrypting my messages).

        Each step derives 64 bytes with HKDF over DH(eph_own, eph_peer), using
        the current root_key as salt: the first 32 bytes become the new
        root_key, the last 32 bytes the new chain key.
        """
        self.eph_peer = eph_peer
        # Update RootKey & receiving ChainKey (mirrors the peer's sending step)
        derived = hkdf(64, self.eph_own.do_exchange(eph_peer), salt=self.root_key)
        self.root_key, self.chain_key_r = derived[:32], derived[32:]
        # Update RootKey & sending ChainKey with a fresh ephemeral key pair;
        # the peer learns its public half from our next message
        self.eph_own = PrivateKey()
        derived = hkdf(64, self.eph_own.do_exchange(eph_peer), salt=self.root_key)
        self.root_key, self.chain_key_s = derived[:32], derived[32:]
class Client:
    """Chat client for the simplified Signal-style protocol.

    Owns the long-term key material (identity key I, signed pre key S and the
    One-Time Pre Keys), registers their public halves with the relay server
    and keeps one ClientSession (root / chain keys) per peer.

    Wire format: 4-byte ASCII tags, '!i' (big-endian signed 32-bit) integers
    and 32-byte raw public keys.
    """
    def __init__(self, server_ip, server_port):
        self.I = PrivateKey() # Identity Key Pair
        self.S = PrivateKey() # Signed Pre Key
        self.O_queue = [PrivateKey() for i in range(O_NUM)] # One-Time Pre Keys
        self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.s.connect((server_ip, server_port))
        self.user_id = self.register()
        # self.existing_sessions[user_id] = session with that user (ClientSession)
        self.existing_sessions = {} # initially no existing session

    def _recv_exact(self, n):
        """Reads exactly n bytes from the server socket.

        socket.recv may return fewer bytes than requested on a TCP stream, so
        keep reading until the whole field has arrived. Raises ConnectionError
        if the server closes the connection first.
        """
        buf = bytearray()
        while len(buf) < n:
            chunk = self.s.recv(n - len(buf))
            if not chunk:
                raise ConnectionError("server closed the connection")
            buf += chunk
        return bytes(buf)

    def _recv_int(self):
        """Reads one '!i' (big-endian signed 32-bit) integer from the server."""
        return struct.unpack('!i', self._recv_exact(4))[0]

    def register(self):
        """Lets the server know about its presence

        Sends the public halves of I and S, then the number of One-Time Pre
        Keys followed by each of them, and returns the user id assigned by
        the server.
        """
        I_pub = self.I.get_public().public
        S_pub = self.S.get_public().public
        self.s.sendall(I_pub)
        print("Sent I = " + I_pub.hex())
        self.s.sendall(S_pub)
        print("Sent S = " + S_pub.hex())
        self.s.sendall(struct.pack('!i', len(self.O_queue)))
        for O in self.O_queue:
            self.s.sendall(O.get_public().public)
        user_id = self._recv_int()
        print("Got User ID = %d" % user_id)
        return user_id

    def get_send_message_key(self, peer_user_id):
        """Gets sending MessageKey using ChainKey from peer_user_id's ClientSession

        MessageKey = HMAC_SHA256(ChainKey, 0x01); the sending ChainKey then
        advances to HMAC_SHA256(ChainKey, 0x02) so every message gets a
        fresh key.
        """
        session = self.existing_sessions[peer_user_id]
        message_key = hmac_sha256(session.chain_key_s, b"\x01")
        session.chain_key_s = hmac_sha256(session.chain_key_s, b"\x02")
        return message_key

    def get_recv_message_key(self, peer_user_id):
        """Gets MessageKey using ChainKey from peer_user_id's ClientSession

        Same derivation as get_send_message_key, applied to the receiving
        ChainKey (chain_key_r).
        """
        session = self.existing_sessions[peer_user_id]
        message_key = hmac_sha256(session.chain_key_r, b"\x01")
        session.chain_key_r = hmac_sha256(session.chain_key_r, b"\x02")
        return message_key

    def pad_message(self, message):
        """PKCS#7-pads message to a multiple of the 16-byte AES block size.

        A str message is UTF-8 encoded first; the result is always bytes.
        When the length is already a multiple of 16, a full block of padding
        is added so the padding can always be removed unambiguously.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        pad_number = 16 - len(message) % 16
        return message + bytes([pad_number]) * pad_number

    def unpad_message(self, message):
        """Removes PKCS#7 padding added by pad_message.

        Raises ValueError if message is empty or the padding is malformed.
        """
        if not message:
            raise ValueError("cannot unpad an empty message")
        last = message[-1:]
        pad_number = last[0] if isinstance(message, (bytes, bytearray)) else ord(last)
        if not 1 <= pad_number <= 16 or len(message) < pad_number \
                or message[-pad_number:] != last * pad_number:
            raise ValueError("invalid padding")
        return message[:-pad_number]

    def send_message(self, to_user_id, message):
        """Sends a message to user with id <to_user_id>

        If there is no existing ClientSession between them, one is set up
        first; if that fails, nothing is sent. The message (str or bytes) is
        encrypted with AES-CBC under the next sending MessageKey and sent as
        SEND | to_user_id | own ephemeral public key | length | IV + ciphertext.
        """
        if to_user_id not in self.existing_sessions:
            if not self.setup_session(to_user_id):
                return
        session = self.existing_sessions[to_user_id]
        # Get MessageKey for next message
        message_key = self.get_send_message_key(to_user_id)
        # Encrypt message using MessageKey; the random IV is prepended so the
        # receiver can split it off again (see recv_message)
        iv = os.urandom(16)
        encryptor = Cipher(algorithms.AES(message_key), modes.CBC(iv)).encryptor()
        enc_message = iv + encryptor.update(self.pad_message(message)) + encryptor.finalize()
        # NOTE(review): nothing authenticates the ciphertext (no MAC) — consider
        # adding one.
        raw_eph_own = session.eph_own.get_public().public
        # One sendall per frame so the server sees the fields in order
        self.s.sendall(b"SEND" + struct.pack('!i', to_user_id) + raw_eph_own
                       + struct.pack('!i', len(enc_message)) + enc_message)

    def recv_message(self):
        """Receives a message

        It can be either from a new session exchange (tag=NEWS), or an actual
        message (tag=MESG). A NEWS frame is handled first and is followed by
        the message it precedes. Returns the decrypted text as str, or None
        if the frame is not a message.
        """
        tag = self._recv_exact(4)
        if tag == b"NEWS": # NEW Session
            ini_user_id = self._recv_int()
            self.setup_session_rec(ini_user_id)
            tag = self._recv_exact(4)
        if tag != b"MESG":
            return None
        # Get sender user id
        sender_user_id = self._recv_int()
        # Get ephemeral key
        raw_eph_peer = self._recv_exact(32)
        session = self.existing_sessions[sender_user_id]
        # An unknown peer ephemeral key (or none yet) means the peer ratcheted:
        # update root & chain keys before deriving this message's key
        if session.eph_peer is None or session.eph_peer.public != raw_eph_peer:
            session.update_keys(PublicKey(raw_eph_peer))
        # Get message length and message
        mlen = self._recv_int()
        enc_message = self._recv_exact(mlen)
        # Decrypt message using MessageKey
        message_key = self.get_recv_message_key(sender_user_id)
        iv, enc_message = enc_message[:16], enc_message[16:]
        decryptor = Cipher(algorithms.AES(message_key), modes.CBC(iv)).decryptor()
        msg = self.unpad_message(decryptor.update(enc_message) + decryptor.finalize())
        return msg.decode("utf-8", errors="replace")

    def record_new_client_session(self, master_secret, client_id, eph_peer = None):
        """Forges RootKey and ChainKey from master_secret and stores them in a ClientSession
        which is then put in the existing_sessions dictionary

        The initiator (eph_peer given) immediately derives its sending chain
        from HKDF over DH(eph_own, eph_peer) salted with master_secret. The
        receiver derives its chains when the first message arrives.
        """
        root_key = master_secret # master_secret is considered as the base root key
        if eph_peer is not None: # initiator considers Srec as initial peer ephemeral key
            eph_own = PrivateKey() # Generate new send ephemeral key
            derived = hkdf(64, eph_own.do_exchange(eph_peer), salt=root_key)
            root_key, chain_key_s = derived[:32], derived[32:]
            chain_key_r = None # Will initialize with first received message
        else: # the non-initiator (first receiver) will enter this branch
            eph_own = self.S # initiator considers receiver's S as first "ephemeral" key
            chain_key_s = None
            chain_key_r = None
        new_client_session = ClientSession(root_key, chain_key_s, chain_key_r, eph_own, eph_peer)
        self.existing_sessions[client_id] = new_client_session

    def setup_session(self, to_user_id):
        """Initiates a new session with to_user_id

        Fetches the recipient's I, S and one One-Time Pre Key from the server,
        computes master_secret = DH(Iini, Srec) + DH(Eini, Irec) +
        DH(Eini, Srec) + DH(Eini, Orec), sends Eini through the server and
        records the session. Returns True on success, False if the server
        answers FAIL.
        """
        self.s.sendall(b"GENS") # GENerate Session command
        self.s.sendall(struct.pack('!i', to_user_id))
        resp = self._recv_exact(4)
        if resp == b"FAIL":
            print("User does not exist")
            return False
        Irec = PublicKey(self._recv_exact(32))
        Srec = PublicKey(self._recv_exact(32))
        Orec = PublicKey(self._recv_exact(32))
        # Ephemeral private key of the initiator
        Eini = PrivateKey()
        master_secret = (self.I.do_exchange(Srec) + Eini.do_exchange(Irec)
                         + Eini.do_exchange(Srec) + Eini.do_exchange(Orec))
        # Printed so both clients can be checked for the same secret (lab Task 1);
        # never log key material outside of this exercise.
        print("master_secret = " + master_secret.hex())
        # Send Eini to the other client (through server)
        self.s.sendall(Eini.get_public().public)
        # Generate new ClientSession using master_secret
        self.record_new_client_session(master_secret, to_user_id, Srec)
        return True

    def setup_session_rec(self, from_user_id):
        """Receives a new session with from_user_id

        Reads the initiator's Eini, Iini and the public One-Time Pre Key it
        used, consumes the matching private key from O_queue and computes the
        same master_secret as the initiator. Raises ValueError if the
        One-Time Pre Key is not one of ours.
        """
        Eini = PublicKey(self._recv_exact(32))
        Iini = PublicKey(self._recv_exact(32))
        Orec_pub = self._recv_exact(32)
        Orec = next((O for O in self.O_queue if O.get_public().public == Orec_pub), None)
        if Orec is None:
            raise ValueError("received unknown One-Time Pre Key")
        # One-Time Pre Keys are single use
        self.O_queue.remove(Orec)
        # Same DH terms as the initiator, computed with our private halves
        master_secret = (self.S.do_exchange(Iini) + self.I.do_exchange(Eini)
                         + self.S.do_exchange(Eini) + Orec.do_exchange(Eini))
        print("master_secret = " + master_secret.hex())
        # Generate new ClientSession using master_secret
        self.record_new_client_session(master_secret, from_user_id)
from donna25519 import *
import socket
import threading
import struct
TCP_IP = ""  # empty host string: bind to all local interfaces
class ClientRecord:
    """Server-side bookkeeping for one registered client.

    I: the client's identity public key
    S: the client's signed pre key (public)
    O_queue: list of the client's One-Time Pre Keys not yet handed out
    s: the socket connected to that client
    """
    def __init__(self, I, S, O_queue, clientsocket):
        self.I, self.S = I, S
        self.O_queue = O_queue
        self.s = clientsocket
class Server:
    """Relay server: stores each client's public key bundle and forwards
    session-setup (GENS) and message (SEND) frames between clients.

    Every connection is served by its own thread (see start), so the registry
    and all writes to client sockets are guarded by locks.
    """
    def __init__(self, port):
        self.port = port
        self.next_id = 1
        self.registered_clients = {}
        # Protects next_id / registered_clients against concurrent handler threads
        self._registry_lock = threading.Lock()
        # Serialises writes so frames sent to one socket by different threads
        # never interleave
        self._send_lock = threading.Lock()

    def start(self):
        """Start listening to client connections (never returns)"""
        print("Starting server on port %d" % self.port)
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind((TCP_IP, self.port))
        s.listen()  # accept() only works on a listening socket
        while True:
            c, addr = s.accept() # connection with new client
            print("New client from address " + str(addr))
            threading.Thread(target = self.on_new_client, args = (c, addr)).start()

    @staticmethod
    def _recv_exact(sock, n):
        """Reads exactly n bytes from sock; raises ConnectionError on EOF.

        socket.recv may return fewer bytes than requested on a TCP stream.
        """
        buf = bytearray()
        while len(buf) < n:
            chunk = sock.recv(n - len(buf))
            if not chunk:
                raise ConnectionError("peer closed the connection")
            buf += chunk
        return bytes(buf)

    @staticmethod
    def _recv_int(sock):
        """Reads one '!i' (big-endian signed 32-bit) integer from sock."""
        return struct.unpack('!i', Server._recv_exact(sock, 4))[0]

    def _send(self, sock, data):
        """Sends one complete frame to sock; returns False (and logs) on failure.

        A client that cannot be reached must not tear down the handler
        thread of the client that triggered the write.
        """
        try:
            with self._send_lock:
                sock.sendall(data)
        except OSError as e:
            print("Failed to send to client: %s" % e)
            return False
        return True

    def on_new_client(self, clientsocket, addr):
        """Runs when a new client connects to the server

        Reads the client's key bundle (I, S, then a count-prefixed list of
        One-Time Pre Keys), replies with its new user id and then serves its
        GENS / SEND commands until the connection closes. On disconnect the
        client is removed from the registry and its socket is closed.
        """
        user_id = None
        try:
            I = PublicKey(self._recv_exact(clientsocket, 32))
            print("Received I = " + I.public.hex())
            S = PublicKey(self._recv_exact(clientsocket, 32))
            print("Received S = " + S.public.hex())
            o_num = self._recv_int(clientsocket)
            O_queue = []
            for _ in range(o_num):
                O = PublicKey(self._recv_exact(clientsocket, 32))
                O_queue.append(O)
                print("Received O = " + O.public.hex())
            user_id = self.register_client(I, S, O_queue, clientsocket)
            self._send(clientsocket, struct.pack('!i', user_id))
            while True:
                cmd = self._recv_exact(clientsocket, 4)
                if cmd == b"GENS": # GENerate Session
                    self._handle_gens(clientsocket, user_id)
                elif cmd == b"SEND": # SEND Message
                    self._handle_send(clientsocket, user_id)
                else:
                    print("Ignoring unknown command %r" % cmd)
        except (OSError, ValueError) as e:
            # Thread boundary: log and drop this client only
            print("Client from address %s disconnected: %s" % (addr, e))
        finally:
            if user_id is not None:
                with self._registry_lock:
                    record = self.registered_clients.get(user_id)
                    if record is not None and record.s is clientsocket:
                        del self.registered_clients[user_id]
            clientsocket.close()

    def _handle_gens(self, clientsocket, user_id):
        """Handles GENS from user_id.

        Replies FAIL if the recipient is unknown or has no One-Time Pre Keys
        left. Otherwise replies OKAY + Irec + Srec + Orec, reads Eini from the
        initiator and relays NEWS + initiator id + Eini + Iini + Orec to the
        recipient.
        """
        rec_user_id = self._recv_int(clientsocket)
        Orec = None
        with self._registry_lock:
            record = self.registered_clients.get(rec_user_id)
            # Each One-Time Pre Key is handed out at most once
            if record is not None and record.O_queue:
                Orec = record.O_queue.pop()
        if Orec is None:
            self._send(clientsocket, b"FAIL")
            return
        # Send Irec, Srec, Orec to initiator (after a non-FAIL status tag)
        self._send(clientsocket, b"OKAY" + record.I.public + record.S.public + Orec.public)
        # get Eini from initiator
        Eini = PublicKey(self._recv_exact(clientsocket, 32))
        Iini = self.registered_clients[user_id].I
        # Send Eini, Iini, Orec to recipient, tagged as a NEW Session
        self._send(record.s, b"NEWS" + struct.pack('!i', user_id)
                   + Eini.public + Iini.public + Orec.public)

    def _handle_send(self, clientsocket, user_id):
        """Handles SEND from user_id: forwards the encrypted message.

        The whole frame is read before forwarding so it can be written to the
        recipient in one piece as MESG + sender id + ephemeral key + length +
        message. Messages for unknown recipients are dropped.
        Raises ValueError on a negative message length.
        """
        # Get recipient user id
        rec_user_id = self._recv_int(clientsocket)
        # Sender ephemeral key, message length and message
        eph_s = self._recv_exact(clientsocket, 32)
        raw_mlen = self._recv_exact(clientsocket, 4)
        mlen = struct.unpack('!i', raw_mlen)[0]
        if mlen < 0:
            raise ValueError("invalid message length %d" % mlen)
        raw_msg = self._recv_exact(clientsocket, mlen)
        record = self.registered_clients.get(rec_user_id)
        if record is None:
            print("Dropping message for unknown user %d" % rec_user_id)
            return
        self._send(record.s, b"MESG" + struct.pack('!i', user_id) + eph_s + raw_mlen + raw_msg)

    def register_client(self, I, S, O_queue, clientsocket):
        """Registers new client's I, S, and O_queue and returns the new client's id

        Ids are unique even when several clients register concurrently.
        """
        new_client = ClientRecord(I, S, O_queue, clientsocket)
        with self._registry_lock:
            user_id = self.next_id
            self.next_id += 1
            self.registered_clients[user_id] = new_client
        return user_id