Encrypting with Olm

Now that we have encryption and decryption using Megolm, we need to send the Megolm sessions securely to the recipients. As mentioned earlier, this is done using Olm. An Olm session is a bi-directional encrypted channel between two devices. Unlike Megolm, in which each sender creates their own session, an Olm session allows either side to encrypt. However, Olm only allows encryption between two devices, whereas a single Megolm session can have multiple recipients. Since a single Olm session allows us to encrypt our own messages and decrypt the other party’s messages, and we don’t need to worry about decrypting our own messages, we only need one class for Olm, as opposed to the two classes (outbound and inbound) that we needed for Megolm. However, because of the way we will be decrypting, as we will see below, our class will not represent a single Olm session, but rather all the Olm sessions that we have with another device. (We could have multiple sessions with the same device if, for example, we both try to initiate a session at the same time, or if a session got out of sync and needed to be replaced.)

We will call this grouping of Olm sessions an Olm channel.

src/matrixlib/olm.py:
# {{copyright}}

"""Olm-related functionality"""

import asyncio
from base64 import b64decode, b64encode
import json
import os
import sys
import typing
import vodozemac

from .client import Client
from . import devices
from . import error
from . import rooms
from . import schema


OLM_ALGORITHM = "m.olm.v1.curve25519-aes-sha2"


{{olm module classes}}
olm module classes:
class OlmChannel:
    """Manages a set of Olm sessions with another device"""

    {{OlmChannel member variables}}

    {{OlmChannel class methods}}

There are three ways that an Olm channel can be created: we can initiate the channel ourselves, we can create a channel from a message received from another device, or we can load the channel from storage. As with InboundMegolmSession, we will avoid using the initializer function, and instead create instances of this class using special class methods, in this case create_* methods.

OlmChannel class methods:
def __init__(self):
    """Do not use initializer function.  Use the ``create_*`` methods instead"""
    raise RuntimeError("Use the create_* methods instead")
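
The create_* pattern relies on cls.__new__(cls), which allocates an instance without running __init__, so the initializer itself can refuse to be called. A minimal standalone sketch of the pattern, using a hypothetical Channel class rather than OlmChannel itself:

```python
class Channel:
    """Hypothetical toy class: instances come only from create_* class methods."""

    partner: str

    def __init__(self) -> None:
        raise RuntimeError("Use the create_* methods instead")

    @classmethod
    def create_outbound(cls, partner: str) -> "Channel":
        # cls.__new__ allocates the object without calling __init__
        obj = cls.__new__(cls)
        obj.partner = partner
        return obj


ch = Channel.create_outbound("@bob:example.org")
print(ch.partner)  # @bob:example.org

try:
    Channel()  # calling the initializer directly is refused
except RuntimeError as e:
    print(e)  # Use the create_* methods instead
```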

Initiating a channel

First, we look at initiating a channel, which we will call an outbound channel for the purposes of the create_* method name. (Once the channel is created, there is no difference between an outbound and an inbound channel; the difference is only in how it is created.)

To initiate a channel, we need to initiate an Olm session, and to initiate an Olm session, we need the recipient’s device keys and one of their one-time keys. We could have our create method automatically obtain a one-time key from the server, but the endpoint for claiming one-time keys allows us to claim keys for multiple recipients at once, so we will claim the one-time keys outside of our create method. We will create a function in our client class to claim one-time keys below. For now, we just assume that we have obtained the one-time key somehow. As implied by the name, a one-time key should only be used once: after it has been used, it should be discarded. If a new session needs to be created, a new one-time key will need to be claimed.

We also pass into our create function a client, so that we can store the sessions and associated data, and a device keys manager, since we need its vodozemac Account object to create Olm sessions. As usual, we also allow passing in a key to encrypt the sessions in storage. If no key is passed in, we use the same key as the device keys manager.

Before we create the Olm session, we must verify the signature on both the device keys object and the one-time key using the fingerprint key in the device keys, to ensure that they have not been tampered with. Then we can create the session.
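
We do not show devices.verify_json_ed25519 here. Assuming it follows the Matrix rules for signing JSON, the signature covers the canonical JSON of the object with its signatures and unsigned properties removed. The sketch below (the helper name signed_bytes is hypothetical) computes only those bytes; the ed25519 check itself is left out, since Python’s standard library has no ed25519 implementation.

```python
import json
from typing import Any


def signed_bytes(obj: dict[str, Any]) -> bytes:
    """Hypothetical helper: the bytes that an ed25519 signature on obj covers."""
    # signatures and unsigned are not part of the signed content
    stripped = {k: v for k, v in obj.items() if k not in ("signatures", "unsigned")}
    return json.dumps(
        stripped,
        ensure_ascii=False,     # canonical JSON is UTF-8, not \u-escaped
        separators=(",", ":"),  # no insignificant whitespace
        sort_keys=True,         # keys in lexicographic order, at every level
    ).encode("utf-8")


otk = {
    "key": "abc",
    "signatures": {"@bob:example.org": {"ed25519:HIJKLMN": "sig"}},
}
print(signed_bytes(otk))  # b'{"key":"abc"}'
```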

Our create function will initialize the object’s member variables, and then call another function, add_outbound_olm_session, to create the Olm session. This function can also be used if we later need to create a new Olm session.

OlmChannel class methods:
@classmethod
def create_outbound_channel(
    cls,
    c: Client,
    device_keys_manager: devices.DeviceKeysManager,
    recipient_device_keys: dict,
    recipient_one_time_key: dict,
    key: typing.Optional[bytes] = None,
) -> "OlmChannel":
    """Create a new channel with a new Olm session

    Arguments:

    ``c``:
      the client object
    ``device_keys_manager``:
      a ``DeviceKeysManager`` object
    ``recipient_device_keys``:
      the other party's device keys, as returned by ``/keys/query``
    ``recipient_one_time_key``:
      the other party's signed one-time key, as returned by ``/keys/claim``
    ``key``:
      a 32-byte binary used to encrypt the objects in storage.  If not
      specified, uses the same key as used by ``device_keys_manager``

    Returns a new ``OlmChannel`` with an Olm session
    """
    obj = cls.__new__(cls)

    obj.client = c
    obj.device_keys_manager = device_keys_manager
    obj.key = key if key else device_keys_manager.key

    obj.partner_user_id = recipient_device_keys["user_id"]
    obj.partner_device_id = recipient_device_keys["device_id"]
    partner_keys = recipient_device_keys["keys"]
    obj.partner_identity_key = partner_keys[f"curve25519:{obj.partner_device_id}"]
    obj.partner_fingerprint_key = partner_keys[f"ed25519:{obj.partner_device_id}"]

    # check the signature on the device keys
    devices.verify_json_ed25519(
        typing.cast(str, obj.partner_fingerprint_key),
        obj.partner_user_id,
        typing.cast(str, obj.partner_device_id),
        recipient_device_keys,
    )

    obj.sessions = []
    # note: add_outbound_olm_session will check the signature on the OTK, so we
    # don't need to check the signature separately
    obj.add_outbound_olm_session(recipient_one_time_key)

    obj._store_session_data()

    return obj

def add_outbound_olm_session(self, recipient_one_time_key: dict) -> None:
    """Add a new Olm session to the channel and store it

    Arguments:

    ``recipient_one_time_key``:
      the other party's signed one-time key, as returned by ``/keys/claim``
    """
    if self.partner_fingerprint_key is None or self.partner_device_id is None:
        raise RuntimeError("Unable to verify signature")

    devices.verify_json_ed25519(
        typing.cast(str, self.partner_fingerprint_key),
        self.partner_user_id,
        typing.cast(str, self.partner_device_id),
        recipient_one_time_key,
    )

    session = self.device_keys_manager.account.create_outbound_session(
        self.partner_identity_key,
        recipient_one_time_key["key"],
    )
    self.sessions.append(session)
    # as with any operation on an Olm session, store the updated sessions
    self._store_session_data()

def _store_session_data(self) -> None:
    """Save the pickled sessions and partner details in the client's storage"""
    name = f"olm_session.{self.partner_user_id}.{self.partner_identity_key}"
    to_store: dict[str, typing.Any] = {
        "sessions": [session.pickle(self.key) for session in self.sessions]
    }
    if self.partner_device_id is not None:
        to_store["partner_device_id"] = self.partner_device_id
    if self.partner_fingerprint_key is not None:
        to_store["partner_fingerprint_key"] = self.partner_fingerprint_key
    self.client.storage[name] = to_store

As with our InboundMegolmSession class, we need to declare the member variables to satisfy the type checker.

OlmChannel member variables:
client: Client
device_keys_manager: devices.DeviceKeysManager
key: bytes
partner_user_id: str
partner_device_id: typing.Optional[str]
partner_identity_key: str
partner_fingerprint_key: typing.Optional[str]
sessions: list[vodozemac.Session]

Encrypting

Now that we have created a channel with an Olm session, we can use the Olm session to encrypt an event. As we did for Megolm encryption, our encryption function will take an event type and an event content, and return the event content for an m.room.encrypted event to send to the recipient. (Even though the event is not sent to a room, the event type is still called m.room.encrypted.)

To encrypt the event, as with Megolm, we construct a dict containing the event type and event content. We do not include the room ID as we did with Megolm, since we do not send Olm messages to a room. We also include information identifying our user and device and the recipient’s user and device, so that an attacker can’t publish someone else’s curve25519 key as their own and use it to pass off messages that were actually sent by someone else. The recipient will need to check that this information matches the device that it received the message from.
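
To make this binding concrete, here is a sketch of the plaintext payload and of the recipient-side check on it. The helper names build_olm_plaintext and addressed_to_me and the key strings are hypothetical placeholders; the payload mirrors the dict built in the encrypt function below, but the helpers are not part of the library.

```python
import json


def build_olm_plaintext(event_type: str, content: dict, sender: str,
                        recipient: str, recipient_fingerprint: str,
                        our_fingerprint: str) -> str:
    """Hypothetical helper: the JSON payload that gets Olm-encrypted."""
    return json.dumps({
        "type": event_type,
        "content": content,
        "sender": sender,
        "recipient": recipient,
        "recipient_keys": {"ed25519": recipient_fingerprint},
        "keys": {"ed25519": our_fingerprint},
    })


def addressed_to_me(plaintext: str, my_user_id: str, my_fingerprint: str) -> bool:
    """Hypothetical recipient-side check: is this payload really meant for us?"""
    payload = json.loads(plaintext)
    return (payload["recipient"] == my_user_id
            and payload["recipient_keys"]["ed25519"] == my_fingerprint)


pt = build_olm_plaintext("m.room.message", {"body": "hi"}, "@alice:example.org",
                         "@bob:example.org", "BOB_FP", "ALICE_FP")
```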

If we have multiple Olm sessions, the spec says that we should use the last session that successfully decrypted a message, or the session that we just created, if any. We will keep this session as the last session in our list of sessions.

The reason that we use the last session that successfully decrypted a message is that it is the session most likely to be usable. Sessions can get corrupted, and we will create new sessions to replace them when that happens. If a session has been corrupted, we do not want to use it any more. If a session recently decrypted a message from the other device, then it has a smaller chance of being corrupted. Likewise, if we just created a new session, it should not be corrupt.

Historical note

The ciphertext property in the return value is an object indexed by the recipient’s identity key. This allows a single event to be sent to multiple recipients. This comes from a previous version of end-to-end encryption in Matrix, in which room messages were encrypted using Olm rather than Megolm, and so the room event needed to contain a separate ciphertext for each recipient. Now that Olm-encrypted events are sent only to individual devices, this is no longer necessary, but the format remains.

Unlike in Megolm, where the ciphertext is a single string, with Olm the ciphertext consists of two parts: a type and a body. The type indicates whether it is a pre-key or a “normal” message. A pre-key message (indicated by the type set to 0) is a message encrypted by an Olm session before that session has received an encrypted message from the other party, whereas a normal message (indicated by the type set to 1) is one encrypted after receiving a message from the other party. When sending a message, there is no real difference between pre-key and normal messages. However, when decrypting, we can only create an Olm session from a pre-key message; normal messages can only be decrypted using existing Olm sessions.
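
As a small illustration, here is how a recipient might pick its own entry out of the ciphertext map and tell whether it could create a new session from it. The helper name ciphertext_for and the placeholder key strings are hypothetical.

```python
from typing import Tuple

PRE_KEY_MESSAGE = 0
NORMAL_MESSAGE = 1


def ciphertext_for(content: dict, our_identity_key: str) -> Tuple[int, str]:
    """Hypothetical helper: return (type, body) of the entry for our identity key."""
    entry = content.get("ciphertext", {}).get(our_identity_key)
    if entry is None:
        raise LookupError("The message is not encrypted for us")
    return entry["type"], entry["body"]


content = {
    "algorithm": "m.olm.v1.curve25519-aes-sha2",
    "sender_key": "<Alice's identity key>",
    "ciphertext": {
        "<Bob's identity key>": {"type": PRE_KEY_MESSAGE, "body": "<ciphertext>"},
    },
}

msg_type, body = ciphertext_for(content, "<Bob's identity key>")
# only a pre-key message can be used to create a new Olm session
can_start_session = msg_type == PRE_KEY_MESSAGE
```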

Whenever we perform any operation on an Olm session, we must store the session, so we make sure to do that in our encrypt function too.

OlmChannel class methods:
def encrypt(self, event_type: str, content: dict) -> dict:
    """Encrypt an event using Olm

    Arguments:

    ``event_type``:
      the type of the event (e.g. ``m.room.message``)
    ``content``:
      the event ``content``

    Returns the ``content`` of a ``m.room.encrypted`` event
    """
    if len(self.sessions) == 0:
        raise RuntimeError("No Olm session available")

    plaintext = json.dumps(
        {
            "type": event_type,
            "content": content,
            "sender": self.client.user_id,
            "recipient": self.partner_user_id,
            "recipient_keys": {"ed25519": self.partner_fingerprint_key},
            "keys": {"ed25519": self.device_keys_manager.fingerprint_key},
        }
    )
    # sessions[-1] is the last item in the list
    olm_message = self.sessions[-1].encrypt(plaintext)
    self._store_session_data()

    return {
        "algorithm": OLM_ALGORITHM,
        "sender_key": self.device_keys_manager.identity_key,
        "ciphertext": {
            self.partner_identity_key: {
                "type": olm_message.message_type,
                "body": olm_message.ciphertext,
            },
        },
    }

Decrypting

Decrypting with Olm is a bit more complicated than with Megolm. With Olm, we do not know the ID of the session used to encrypt the message. All we know is the user ID and identity key of the sender, as well as the Olm message type. So to decrypt, we must try decrypting with all the sessions that we have with the other device, and if none of them are able to decrypt and the message is a pre-key message, try to create a new session using the message. This is the reason why we use a single object to represent multiple Olm sessions: because we don’t know which session we should use to decrypt a message.
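
The overall strategy can be sketched with stand-in sessions. This is not vodozemac: ToySession, toy_create_inbound and trial_decrypt are hypothetical, and a toy session “decrypts” a body only if it starts with that session’s ID. The sketch shows the order of attempts: existing sessions starting from the end of the list, then a new session, but only for a pre-key message.

```python
from typing import Callable, List, Tuple

OlmMsg = Tuple[int, str]  # (message type, body); 0 = pre-key, 1 = normal


class UnableToDecrypt(Exception):
    """No existing or new session could decrypt the message."""


class ToySession:
    """Stand-in for an Olm session: 'decrypts' bodies of the form '<id>:<text>'."""

    def __init__(self, session_id: str) -> None:
        self.session_id = session_id

    def decrypt(self, message: OlmMsg) -> str:
        tag, _, text = message[1].partition(":")
        if tag != self.session_id:
            raise ValueError("not encrypted with this session")
        return text


def toy_create_inbound(message: OlmMsg) -> Tuple[ToySession, str]:
    # pretend to establish a session from a pre-key message
    tag, _, text = message[1].partition(":")
    return ToySession(tag), text


def trial_decrypt(
    sessions: List[ToySession],
    message: OlmMsg,
    create_inbound: Callable[[OlmMsg], Tuple[ToySession, str]],
) -> str:
    # 1. try the existing sessions, starting from the end of the list
    for index in range(len(sessions) - 1, -1, -1):
        try:
            plaintext = sessions[index].decrypt(message)
        except ValueError:
            continue
        sessions.append(sessions.pop(index))  # the successful session goes last
        return plaintext
    # 2. only a pre-key message can establish a new session
    if message[0] == 0:
        try:
            session, plaintext = create_inbound(message)
        except ValueError:
            raise UnableToDecrypt() from None
        sessions.append(session)
        return plaintext
    raise UnableToDecrypt()


sessions = [ToySession("s1"), ToySession("s2")]
print(trial_decrypt(sessions, (1, "s1:hello"), toy_create_inbound))  # hello
print([s.session_id for s in sessions])  # ['s2', 's1']
print(trial_decrypt(sessions, (0, "s3:new"), toy_create_inbound))  # new
```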

OlmChannel class methods:
# FIXME: needs more explanation, split into chunks
def decrypt(self, event_content: dict) -> dict:
    """Decrypt an ``m.room.encrypted`` event encrypted with Olm

    Creates a new Olm session if necessary.

    Arguments:

    ``event_content``:
      the ``content`` of the ``m.room.encrypted`` event

    Returns the decrypted event, which will be a dict that should have ``type``
    (the decrypted event type), ``content`` (the event content), and
    information about the sender and recipient.
    """
    {{check cleartext Olm event content}}

    {{try to decrypt with existing Olm session}}

    {{try to create new Olm session from event}}

    {{handle unable to decrypt Olm session}}

First we do some basic sanity checking on the event. We ensure that it is in the expected format, encrypted with Olm, and that we are the intended recipient for the message. (We could also check that the sender_key matches the identity key of the other party in the channel, but we will in fact be looking up the channel based on that property, so checking it would be redundant.)

check cleartext Olm event content:
schema.ensure_valid(
    event_content,
    {
        "algorithm": str,
        "sender_key": str,
        "ciphertext": schema.Object({"type": int, "body": str}),
    },
)

if event_content["algorithm"] != OLM_ALGORITHM:
    raise RuntimeError("Invalid algorithm")

olm_message_dict = event_content.get("ciphertext", {}).get(
    self.device_keys_manager.identity_key
)
if olm_message_dict is None:
    raise RuntimeError("The message is not encrypted for us")

Next, we try to decrypt using one of our existing Olm sessions. This is just a loop through our sessions, trying to decrypt with each one. If decryption with a session fails, we assume that the message wasn’t encrypted using that session, and try the next one.

We iterate through our sessions in reverse order, starting from the last session in our list, because in general that is the most likely one to have been used. As mentioned above, we maintain our sessions list so that the last session to successfully decrypt a message is the last item in the list. In practice, this probably does not make much of a difference, but it is also not that much more difficult to iterate in reverse order.

Since we want to maintain our sessions list so that the last session to successfully decrypt is the last item in the list, if we do decrypt with a session, we move it to the end of the list if it is not already there. We also need to store our session data after decrypting.

We then parse the decrypted message as JSON and do some more checking on the content using another function that we will write shortly. If all checks out, then we return the plaintext.

try to decrypt with existing Olm session:
olm_message = vodozemac.OlmMessage(
    olm_message_dict["type"], olm_message_dict["body"]
)

for index in range(len(self.sessions) - 1, -1, -1):
    session = self.sessions[index]
    try:
        plaintext_str = session.decrypt(olm_message)
    except Exception:
        # not encrypted with this session (or the session is unusable), so
        # try the next one
        continue

    # on successful decryption, move the session to the end of our list,
    # keeping the other sessions in their existing order
    if index != len(self.sessions) - 1:
        self.sessions.append(self.sessions.pop(index))
    self._store_session_data()

    plaintext = json.loads(plaintext_str)
    self._check_plaintext(plaintext)

    return plaintext
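
Swapping the successful session into the last slot, and removing it and re-appending it, both put that session last, but they differ in what happens to the rest of the list. With sessions ordered from least to most recently useful, removing and re-appending preserves the relative order of the others, so a reverse iteration keeps trying the more recently useful sessions first; a swap moves the previous last session to the front. A standalone comparison on plain lists (both function names are hypothetical):

```python
def swap_to_end(sessions: list, index: int) -> list:
    """Put sessions[index] last by swapping it with the current last item."""
    result = list(sessions)
    result[index], result[-1] = result[-1], result[index]
    return result


def move_to_end(sessions: list, index: int) -> list:
    """Put sessions[index] last by removing it and re-appending it."""
    result = list(sessions)
    result.append(result.pop(index))
    return result


# sessions ordered least to most recently useful; "A" (index 0) just decrypted
print(swap_to_end(["A", "B", "C"], 0))  # ['C', 'B', 'A']
print(move_to_end(["A", "B", "C"], 0))  # ['B', 'C', 'A']
```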

If we were unable to decrypt using an existing session, then the message may have been encrypted using a new session, so we need to see if we can create a new session from the message. This can only happen if the message is a pre-key message (indicated by the type property being set to 0; if it is set to 1, then it is a normal message). So we check the message type and attempt to create a new session from it, which also gives us the plaintext if successful. If that succeeds, we add the new session to our list of sessions, store them, and, as above, parse the plaintext as JSON, check it, and return it.

try to create new Olm session from event:
if olm_message_dict["type"] == 0:
    try:
        (
            session,
            plaintext_str,
        ) = self.device_keys_manager.account.create_inbound_session(
            self.partner_identity_key,
            olm_message,
        )
    except Exception:
        {{handle unable to decrypt Olm session}}

    self.sessions.append(session)
    # FIXME: truncate self.sessions
    self._store_session_data()

    plaintext = json.loads(plaintext_str)
    self._check_plaintext(plaintext)

    return plaintext

If we are unable to decrypt the message, either with an existing session or by creating a new session, then we raise an exception to indicate this to the application. The application should then try to create a new Olm session, as this indicates that there may be no usable Olm session between us and the other party. We will discuss exactly how to do this later; for now, we just raise an exception. Note that the decrypt function may also raise exceptions in other cases, if validation of the cleartext or plaintext portions fails. However, the application should only create a new Olm session if the failure is in decrypting, and not in any of the other cases, as those do not indicate a problem with the Olm session itself. Thus we create a new exception class for this failure.

handle unable to decrypt Olm session:
raise error.UnableToDecryptError()

error module classes:
class UnableToDecryptError(RuntimeError):
    """We were unable to decrypt the message using Olm"""

    pass
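
On the application side, the separate class lets a handler tell the two kinds of failure apart. Because UnableToDecryptError is a subclass of RuntimeError, its except clause must come first. A sketch, with a hypothetical handle function and callbacks, and a local copy of the error class so that it runs on its own:

```python
from typing import Callable


class UnableToDecryptError(RuntimeError):
    """Local copy of the error class: decryption with Olm failed."""


def handle(decrypt: Callable[[], str], restart_session: Callable[[], None]) -> str:
    """Hypothetical application-side handler for an incoming Olm message."""
    try:
        return decrypt()
    except UnableToDecryptError:
        # no usable Olm session: arrange for a new one to be created
        restart_session()
        return "restarted"
    except RuntimeError:
        # invalid or mis-addressed message: drop it, but keep the sessions
        return "dropped"
```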

As mentioned above, we need to check the plaintext after decrypting the message. We check that it has the expected format, that the sender property matches the user ID of the other side of the Olm channel, and that the recipient and recipient_keys properties match our own user ID and fingerprint key, respectively. The keys property, however, needs special treatment. Normally, it should match the sender’s fingerprint key. However, we may not have access to the sender’s fingerprint key: if the sender sends the event and then logs out before we receive it, and we haven’t previously seen that device’s keys, we won’t be able to obtain the device’s keys, and so we won’t have the fingerprint key. (We still have the identity key, since that is in the cleartext part of the event.) The best we can do in this case is to record the fingerprint key from the first message and match it against any subsequent encrypted messages we get from that device. The application, however, should note that it never got the device’s keys, and so was unable to verify the fingerprint key.

OlmChannel class methods:
def _check_plaintext(self, plaintext: dict) -> None:
    schema.ensure_valid(
        plaintext,
        {
            "type": str,
            "content": dict,
            "sender": str,
            "recipient": str,
            "recipient_keys": {"ed25519": str},
            "keys": {"ed25519": str},
        },
    )
    if (
        plaintext["sender"] != self.partner_user_id
        or plaintext["recipient"] != self.client.user_id
        or plaintext["recipient_keys"]["ed25519"]
        != self.device_keys_manager.fingerprint_key
    ):
        raise RuntimeError("Invalid message")
    if self.partner_fingerprint_key:
        if plaintext["keys"]["ed25519"] != self.partner_fingerprint_key:
            raise RuntimeError("Mismatched fingerprint key")
    else:
        self.partner_fingerprint_key = plaintext["keys"]["ed25519"]
        self._store_session_data()
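
When no fingerprint key is known, the keys check amounts to trust on first use. The toy class below (FingerprintPin is hypothetical) isolates that logic and also records whether the key it holds came from the server, which is the distinction the application needs to keep:

```python
from typing import Optional


class FingerprintPin:
    """Hypothetical toy version of the ``keys`` check on decrypted Olm payloads."""

    def __init__(self, known_key: Optional[str] = None) -> None:
        self.key = known_key
        # True only if the key came from device keys obtained from the server
        self.verified = known_key is not None

    def check(self, claimed_key: str) -> None:
        if self.key is None:
            # nothing to compare against: pin the first key seen (unverified)
            self.key = claimed_key
        elif claimed_key != self.key:
            raise RuntimeError("Mismatched fingerprint key")


pin = FingerprintPin()  # the sender's device keys were never obtained
pin.check("KEY1")       # first message: the key is pinned
pin.check("KEY1")       # a later message with the same key is accepted
```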

We can now use our decryption function to write a function that creates an OlmChannel in response to an encrypted event. As when creating an outbound channel, we pass it a client object and our device keys manager. We also pass in information about the other party, the content of the event that we want to decrypt, and (optionally) a key for encrypting the sessions in storage. When creating an Olm channel from an encrypted message, the other device’s ID and fingerprint key are optional, as we may be decrypting a message from a device that no longer exists (e.g. it logged out before we received the message). Without the device ID, we will be unable to create a new outbound Olm session in the channel. If the fingerprint key is not provided, the channel will take the fingerprint key from the plaintext after decrypting the message. However, applications should not rely on this, and should provide the fingerprint key whenever they have the device’s keys available.

On success, our function will return a tuple consisting of the new Olm channel and the decrypted event. If we fail for some reason, the behaviour will depend on whether the failure happened before the actual decryption (for example, if the cleartext event content was invalid), during decryption, or after (if the decrypted event was invalid). If the failure happened before or during decryption, that means that no Olm session was created at all, so the function will raise an exception since the Olm channel isn’t usable. If the failure happened after decryption, that means that an Olm session was created and so the Olm channel is usable. So we will return the Olm channel, but we will also return the exception; we do this by returning a tuple.
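
A caller therefore has to handle both shapes of result. A sketch of one way to do so, with a hypothetical accept_event function: the channel is kept whenever one is returned, and the event is used only if it is a dict rather than an exception.

```python
from typing import Optional, Tuple, Union


def accept_event(
    result: Tuple[object, Union[dict, BaseException]],
) -> Tuple[object, Optional[dict]]:
    """Hypothetical caller: keep the channel, use the event only if it is valid."""
    channel, outcome = result
    if isinstance(outcome, BaseException):
        # a session was created, so keep the channel, but the decrypted event
        # failed validation and must not be processed
        return channel, None
    return channel, outcome
```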

OlmChannel class methods:
@classmethod
def create_from_encrypted_message(
    cls,
    c: Client,
    device_keys_manager: devices.DeviceKeysManager,
    partner_user_id: str,
    partner_identity_key: str,
    event_content: dict,
    partner_device_id: typing.Optional[str] = None,
    partner_fingerprint_key: typing.Optional[str] = None,
    key: typing.Optional[bytes] = None,
) -> typing.Tuple["OlmChannel", typing.Union[dict, BaseException]]:
    """Create a new channel from an encrypted message

    Arguments:

    ``c``:
      the client object
    ``device_keys_manager``:
      a ``DeviceKeysManager`` object
    ``partner_user_id``:
      the other party's user ID, as returned by ``/keys/query``
    ``partner_identity_key``:
      the other party's identity key
    ``event_content``:
      the content of the ``m.room.encrypted`` event
    ``partner_device_id``:
      the other party's device ID.  You will not be able to create a new
      outbound Olm session without a device ID.  The device ID can be set
      later by setting the ``OlmChannel`` object's ``partner_device_id``
      property
    ``partner_fingerprint_key``:
      the other party's fingerprint key.  If not provided, will be set to the
      fingerprint key provided in the plaintext.  However, the message may not
      be trusted unless it matches the device key obtained from the server, and
      this key should be provided if it is available.
    ``key``:
      a 32-byte binary used to encrypt the objects in storage.  If not
      specified, uses the same key as used by ``device_keys_manager``

    On success, returns a tuple consisting of the new OlmChannel object and the
    decrypted message.  On failure, either raises an exception, or returns a
    tuple consisting of the new OlmChannel object and the exception, depending
    on whether the OlmChannel could be created.
    """
    obj = cls.__new__(cls)

    obj.client = c
    obj.device_keys_manager = device_keys_manager
    obj.key = key if key else device_keys_manager.key

    obj.partner_user_id = partner_user_id
    obj.partner_device_id = partner_device_id
    obj.partner_identity_key = partner_identity_key
    obj.partner_fingerprint_key = partner_fingerprint_key

    obj.sessions = []

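    try:
        decrypted = obj.decrypt(event_content)
    except error.UnableToDecryptError:
        # no Olm session could decrypt the message, so the channel is unusable
        raise
    except Exception as e:
        if len(obj.sessions) == 0:
            # the failure happened before any session was created
            raise
        # a session was created, but the decrypted event was invalid, so
        # return the usable channel along with the exception
        return obj, e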

    return obj, decrypted

Tests
tests/test_olm.py:
# {{copyright}}

import asyncio
import aioresponses
import json
import pytest
import vodozemac

from matrixlib import client
from matrixlib import devices
from matrixlib import error
from matrixlib import olm


{{test olm}}

We test this by creating two clients, called “Alice” and “Bob”. Alice will create an Olm channel and encrypt an event. Bob will take the encrypted event and decrypt it, creating an Olm channel in the process. Bob will then encrypt his own event using the same Olm channel, and Alice will decrypt it. We also test that the channel can handle multiple Olm sessions.

First, we create the two clients.

test olm:
@pytest.mark.asyncio
async def test_olm(mock_aioresponse):
    async with client.Client(
        storage={
            "access_token": "anaccesstoken",
            "user_id": "@alice:example.org",
            "device_id": "ABCDEFG",
        },
        callbacks={},
        base_client_url="https://matrix-client.example.org/_matrix/client/",
    ) as alice:
        async with client.Client(
            storage={
                "access_token": "anaccesstoken",
                "user_id": "@bob:example.org",
                "device_id": "HIJKLMN",
            },
            callbacks={},
            base_client_url="https://matrix-client.example.org/_matrix/client/",
        ) as bob:
            {{olm test}}

Next we need to create the device keys managers for the two clients. We capture the device keys and one-time keys that each manager uploads, so that the other client can use them.

olm test:
alice_device_keys = None
alice_otks = None

def alice_callback(url, **kwargs):
    nonlocal alice_device_keys, alice_otks
    alice_device_keys = kwargs["json"]["device_keys"]
    alice_otks = kwargs["json"]["one_time_keys"]

    return aioresponses.CallbackResult(
        status=200,
        body='{"one_time_key_counts":{"signed_curve25519":100}}',
        headers={
            "Content-Type": "application/json",
        },
    )

mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/upload",
    callback=alice_callback,
)
alice_device_keys_manager = devices.DeviceKeysManager(alice, b"\x00" * 32)
await asyncio.sleep(0.1)

bob_device_keys = None
bob_otks = None

def bob_callback(url, **kwargs):
    nonlocal bob_device_keys, bob_otks
    bob_device_keys = kwargs["json"]["device_keys"]
    bob_otks = kwargs["json"]["one_time_keys"]

    return aioresponses.CallbackResult(
        status=200,
        body='{"one_time_key_counts":{"signed_curve25519":100}}',
        headers={
            "Content-Type": "application/json",
        },
    )

mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/upload",
    callback=bob_callback,
)
bob_device_keys_manager = devices.DeviceKeysManager(bob, b"\x00" * 32)
await asyncio.sleep(0.1)

Alice takes Bob’s device keys and one of his one-time keys, and creates an Olm channel.

olm test:
# use the first `signed_curve25519` one-time key
otk = [
    key
    for id, key in bob_otks.items()
    if id.startswith("signed_curve25519:")
][0]

alice_channel = olm.OlmChannel.create_outbound_channel(
    alice,
    alice_device_keys_manager,
    bob_device_keys,
    otk,
)

She then encrypts an event and Bob decrypts it. We check that the decrypted event matches what Alice sent.

olm test:
alice_msg_encrypted = alice_channel.encrypt(
    "m.room.message", {"body": "Hello world!"}
)

(
    bob_channel,
    alice_msg_decrypted,
) = olm.OlmChannel.create_from_encrypted_message(
    bob,
    bob_device_keys_manager,
    "@alice:example.org",
    alice_device_keys_manager.identity_key,
    alice_msg_encrypted,
    partner_device_id="ABCDEFG",
    partner_fingerprint_key=alice_device_keys_manager.fingerprint_key,
)

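assert alice_msg_decrypted["content"]["body"] == "Hello world!"
assert alice_msg_decrypted["type"] == "m.room.message"
assert alice_msg_decrypted["sender"] == "@alice:example.org"
assert alice_msg_decrypted["recipient"] == "@bob:example.org"
assert (
    alice_msg_decrypted["keys"]["ed25519"]
    == alice_device_keys_manager.fingerprint_key
)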

Bob then encrypts an event and Alice decrypts it.

olm test:
bob_msg_encrypted = bob_channel.encrypt(
    "m.room.message", {"body": "Bonjour!"}
)

bob_msg_decrypted = alice_channel.decrypt(bob_msg_encrypted)

assert bob_msg_decrypted["content"]["body"] == "Bonjour!"

Bob now creates a new Olm session using one of Alice’s one-time keys, and encrypts another event using the channel. When Alice decrypts that event, we should see that she has two Olm sessions.

olm test:
# use the first `signed_curve25519` one-time key
otk = [
    key
    for id, key in alice_otks.items()
    if id.startswith("signed_curve25519:")
][0]

bob_channel.add_outbound_olm_session(otk)

bob_msg2_encrypted = bob_channel.encrypt(
    "m.room.message", {"body": "Guten Tag!"}
)

bob_msg2_decrypted = alice_channel.decrypt(bob_msg2_encrypted)

assert bob_msg2_decrypted["content"]["body"] == "Guten Tag!"

assert len(alice_channel.sessions) == 2

We also want to ensure that the decrypt function will catch the errors that it needs to.

test olm:
@pytest.mark.asyncio
async def test_olm_error_checking(mock_aioresponse):
    async with client.Client(
        storage={
            "access_token": "anaccesstoken",
            "user_id": "@alice:example.org",
            "device_id": "ABCDEFG",
        },
        callbacks={},
        base_client_url="https://matrix-client.example.org/_matrix/client/",
    ) as alice:
        async with client.Client(
            storage={
                "access_token": "anaccesstoken",
                "user_id": "@bob:example.org",
                "device_id": "HIJKLMN",
            },
            callbacks={},
            base_client_url="https://matrix-client.example.org/_matrix/client/",
        ) as bob:
            {{olm error checking test}}

olm error checking test:
alice_device_keys = None
alice_otks = None

def alice_callback(url, **kwargs):
    nonlocal alice_device_keys, alice_otks
    alice_device_keys = kwargs["json"]["device_keys"]
    alice_otks = kwargs["json"]["one_time_keys"]

    return aioresponses.CallbackResult(
        status=200,
        body='{"one_time_key_counts":{"signed_curve25519":100}}',
        headers={
            "Content-Type": "application/json",
        },
    )

mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/upload",
    callback=alice_callback,
)
alice_device_keys_manager = devices.DeviceKeysManager(alice, b"\x00" * 32)
await asyncio.sleep(0.1)

bob_device_keys = None
bob_otks = None

def bob_callback(url, **kwargs):
    nonlocal bob_device_keys, bob_otks
    bob_device_keys = kwargs["json"]["device_keys"]
    bob_otks = kwargs["json"]["one_time_keys"]

    return aioresponses.CallbackResult(
        status=200,
        body='{"one_time_key_counts":{"signed_curve25519":100}}',
        headers={
            "Content-Type": "application/json",
        },
    )

mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/upload",
    callback=bob_callback,
)
bob_device_keys_manager = devices.DeviceKeysManager(bob, b"\x00" * 32)
await asyncio.sleep(0.1)

We will have Alice create an Olm channel and encrypt various messages, but we will modify the messages in some way so that Bob will not be able to process them. Each time Bob tries to decrypt, he should get an error. First, we modify the message so that Bob is not a recipient:

olm error checking test:
# use the first `signed_curve25519` one-time key
otk = [
    key
    for id, key in bob_otks.items()
    if id.startswith("signed_curve25519:")
][0]

alice_channel = olm.OlmChannel.create_outbound_channel(
    alice,
    alice_device_keys_manager,
    bob_device_keys,
    otk,
)

alice_msg_encrypted = alice_channel.encrypt(
    "m.room.message", {"body": "Wrong recipient"}
)

ciphertext = alice_msg_encrypted["ciphertext"][
    bob_device_keys_manager.identity_key
]
del alice_msg_encrypted["ciphertext"][bob_device_keys_manager.identity_key]
alice_msg_encrypted["ciphertext"]["something"] = ciphertext

with pytest.raises(
    RuntimeError, match="The message is not encrypted for us"
):
    olm.OlmChannel.create_from_encrypted_message(
        bob,
        bob_device_keys_manager,
        "@alice:example.org",
        alice_device_keys_manager.identity_key,
        alice_msg_encrypted,
        partner_device_id="ABCDEFG",
        partner_fingerprint_key=alice_device_keys_manager.fingerprint_key,
    )
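
The first check is simple: the encrypted event's ciphertext property maps each recipient's Curve25519 identity key to a ciphertext object, so the recipient only has to look up its own identity key. The following is a minimal sketch of that lookup. The helper name our_ciphertext is hypothetical, and the error message is copied from the test above rather than from the real OlmChannel code.

```python
from typing import Any


def our_ciphertext(content: dict[str, Any], our_identity_key: str) -> dict[str, Any]:
    """Return the ciphertext object addressed to our identity key.

    Hypothetical helper: `ciphertext` is assumed to map each recipient's
    Curve25519 identity key to an object {"type": 0 or 1, "body": str}.
    """
    ciphertext = content.get("ciphertext")
    if not isinstance(ciphertext, dict) or our_identity_key not in ciphertext:
        # Message text taken from the test above (an assumption about how
        # the real implementation phrases it).
        raise RuntimeError("The message is not encrypted for us")
    return ciphertext[our_identity_key]
```
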

Next, we will modify some of the properties within the plaintext. Since we cannot modify the properties after the plaintext has been encrypted, we will instead modify Alice’s Olm channel so that it uses incorrect values when it encrypts the event. Another approach would be to modify Bob’s Olm channel so that the values that it expects differ from the ones that Alice supplies; in fact, this is what we must do to test a wrong sender fingerprint key, since Alice’s fingerprint key is managed by her vodozemac account object, which we cannot modify. Note that since these incorrect fields are inside the plaintext, the first message still decrypts successfully and establishes an Olm session. The Olm channel is therefore returned along with the exception, instead of the exception being raised. We save that channel and use it to attempt decryption of the remaining messages.

olm error checking test:
# wrong sender user ID
alice_channel.client.storage["user_id"] = "@carol:example.org"

alice_msg_encrypted = alice_channel.encrypt(
    "m.room.message", {"body": "Wrong sender user ID"}
)

bob_channel, exc = olm.OlmChannel.create_from_encrypted_message(
    bob,
    bob_device_keys_manager,
    "@alice:example.org",
    alice_device_keys_manager.identity_key,
    alice_msg_encrypted,
    partner_device_id="ABCDEFG",
    partner_fingerprint_key=alice_device_keys_manager.fingerprint_key,
)

assert isinstance(exc, RuntimeError)
assert str(exc) == "Invalid message"

# set the value back to the right one
alice_channel.client.storage["user_id"] = "@alice:example.org"

# wrong recipient user ID
alice_channel.partner_user_id = "@carol:example.org"

alice_msg_encrypted = alice_channel.encrypt(
    "m.room.message", {"body": "Wrong recipient user ID"}
)

with pytest.raises(RuntimeError, match="Invalid message"):
    bob_channel.decrypt(alice_msg_encrypted)

# set the value back to the right one
alice_channel.partner_user_id = "@bob:example.org"

# wrong recipient fingerprint key
alice_channel.partner_fingerprint_key = "not the right key"

alice_msg_encrypted = alice_channel.encrypt(
    "m.room.message", {"body": "Wrong recipient fingerprint key"}
)

with pytest.raises(RuntimeError, match="Invalid message"):
    bob_channel.decrypt(alice_msg_encrypted)

# set the value back to the right one
alice_channel.partner_fingerprint_key = (
    bob_device_keys_manager.fingerprint_key
)

# wrong sender fingerprint key
alice_msg_encrypted = alice_channel.encrypt(
    "m.room.message", {"body": "Wrong sender fingerprint key"}
)

bob_channel.partner_fingerprint_key = "a different key"

with pytest.raises(RuntimeError, match="Mismatched fingerprint key"):
    bob_channel.decrypt(alice_msg_encrypted)

# set the value back to the right one
bob_channel.partner_fingerprint_key = (
    alice_device_keys_manager.fingerprint_key
)
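
These checks run against the decrypted plaintext. According to the Matrix specification for m.olm.v1.curve25519-aes-sha2, the plaintext carries sender, recipient, recipient_keys (the recipient's Ed25519 key) and keys (the sender's Ed25519 key) alongside type and content. The sketch below shows one way to validate those fields. The function name is hypothetical, and the two error strings simply mirror what the tests above expect.

```python
from typing import Any, Optional


def _ed25519(keys: Any) -> Optional[str]:
    # `recipient_keys` and `keys` are objects of the form {"ed25519": "<key>"}
    return keys.get("ed25519") if isinstance(keys, dict) else None


def check_olm_plaintext(
    plaintext: dict[str, Any],
    *,
    partner_user_id: str,
    partner_fingerprint_key: str,
    our_user_id: str,
    our_fingerprint_key: str,
) -> None:
    """Hypothetical validation of a decrypted Olm plaintext payload."""
    # Fields the sender fills in from its own view of the conversation
    if (
        plaintext.get("sender") != partner_user_id
        or plaintext.get("recipient") != our_user_id
        or _ed25519(plaintext.get("recipient_keys")) != our_fingerprint_key
    ):
        raise RuntimeError("Invalid message")
    # The sender's Ed25519 key must match the one we already have on record
    if _ed25519(plaintext.get("keys")) != partner_fingerprint_key:
        raise RuntimeError("Mismatched fingerprint key")
```
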

Finally, we check that it raises the right exception if it is unable to decrypt an event. We can do that by replacing the ciphertext with garbage.

olm error checking test:
alice_msg_encrypted = alice_channel.encrypt(
    "m.room.message", {"body": "Hello"}
)

alice_msg_encrypted["ciphertext"][bob_device_keys_manager.identity_key] = {
    "type": 0,
    "body": "Cannot decrypt this",
}

with pytest.raises(error.UnableToDecryptError):
    bob_channel.decrypt(alice_msg_encrypted)
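
Each ciphertext object has a type (0 for a pre-key message, 1 for a normal message) and a body that, in Matrix, is unpadded base64. The garbage body used above is not even valid base64, so it fails before any cryptography happens. The sketch below illustrates such a pre-check. It is not the real decryption path: a body that is valid base64 but not a valid Olm message would pass here and fail later inside vodozemac. UnableToDecryptError is defined locally as a stand-in for matrixlib's error.UnableToDecryptError, and decode_olm_body is a hypothetical name.

```python
import base64
import binascii
from typing import Any


class UnableToDecryptError(Exception):
    """Local stand-in for matrixlib's error.UnableToDecryptError."""


def decode_olm_body(entry: dict[str, Any]) -> tuple[int, bytes]:
    """Hypothetical pre-check: validate the type and base64-decode the body."""
    msg_type = entry.get("type")
    if msg_type not in (0, 1):
        raise UnableToDecryptError(f"unknown Olm message type {msg_type!r}")
    body = entry.get("body")
    if not isinstance(body, str):
        raise UnableToDecryptError("missing ciphertext body")
    # Matrix uses unpadded base64, so restore the padding before decoding
    padded = body + "=" * (-len(body) % 4)
    try:
        return msg_type, base64.b64decode(padded, validate=True)
    except (binascii.Error, ValueError) as e:
        raise UnableToDecryptError("ciphertext body is not valid base64") from e
```
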

Loading a channel from storage

We also need to be able to load a channel from storage so that we can handle client restarts. We look up the channel by user ID and identity key, since that is the information that we have when we receive an encrypted message.

OlmChannel class methods:
@classmethod
def create_from_storage(
    cls,
    c: Client,
    device_keys_manager: devices.DeviceKeysManager,
    partner_user_id: str,
    partner_identity_key: str,
    key: typing.Optional[bytes] = None,
) -> typing.Optional["OlmChannel"]:
    """Loads an Olm channel from storage

    ``c``:
      the client object
    ``device_keys_manager``:
      a ``DeviceKeysManager`` object
    ``partner_user_id``:
      the other party's user ID, as returned by ``/keys/query``
    ``partner_identity_key``:
      the other party's identity key
    ``key``:
      a 32-byte binary key used to encrypt the objects in storage.  If not
      specified, uses the same key as used by ``device_keys_manager``
    """
    name = f"olm_session.{partner_user_id}.{partner_identity_key}"
    stored = c.storage.get(name)
    if stored is None:
        return None

    obj = cls.__new__(cls)

    obj.client = c
    obj.device_keys_manager = device_keys_manager
    obj.key = key if key else device_keys_manager.key

    obj.partner_user_id = partner_user_id
    obj.partner_identity_key = partner_identity_key

    obj.partner_device_id = stored.get("partner_device_id")
    obj.partner_fingerprint_key = stored.get("partner_fingerprint_key")
    obj.sessions = [
        vodozemac.Session.from_pickle(session, obj.key)
        for session in stored["sessions"]
    ]

    return obj
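
This section only shows how a channel is read back. The sketch below shows the storage record that create_from_storage expects, paired with a hypothetical writer so that the layout can be exercised in a round trip. It assumes that storage behaves like a dict (as it does in the tests) and that each session has already been serialized to a string, for example as an encrypted pickle of the vodozemac session. The function names are illustrative only.

```python
from typing import Any, MutableMapping, Optional


def olm_storage_name(partner_user_id: str, partner_identity_key: str) -> str:
    # Same naming scheme that create_from_storage uses
    return f"olm_session.{partner_user_id}.{partner_identity_key}"


def save_olm_channel_record(
    storage: MutableMapping[str, Any],
    partner_user_id: str,
    partner_identity_key: str,
    partner_device_id: Optional[str],
    partner_fingerprint_key: Optional[str],
    pickled_sessions: list[str],
) -> None:
    """Hypothetical writer producing the record create_from_storage reads."""
    storage[olm_storage_name(partner_user_id, partner_identity_key)] = {
        "partner_device_id": partner_device_id,
        "partner_fingerprint_key": partner_fingerprint_key,
        # each entry is one serialized (pickled) Olm session
        "sessions": list(pickled_sessions),
    }


def load_olm_channel_record(
    storage: MutableMapping[str, Any],
    partner_user_id: str,
    partner_identity_key: str,
) -> Optional[dict[str, Any]]:
    # Returns None when no channel was stored for this user/identity key pair
    return storage.get(olm_storage_name(partner_user_id, partner_identity_key))
```

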
Tests
test olm:
@pytest.mark.asyncio
async def test_load_olm_channel(mock_aioresponse):
    async with client.Client(
        storage={
            "access_token": "anaccesstoken",
            "user_id": "@alice:example.org",
            "device_id": "ABCDEFG",
        },
        callbacks={},
        base_client_url="https://matrix-client.example.org/_matrix/client/",
    ) as alice:
        async with client.Client(
            storage={
                "access_token": "anaccesstoken",
                "user_id": "@bob:example.org",
                "device_id": "HIJKLMN",
            },
            callbacks={},
            base_client_url="https://matrix-client.example.org/_matrix/client/",
        ) as bob:
            {{load olm channel test}}

To test this, we first create Olm channels for Alice and Bob as we did in our previous tests.

load olm channel test:
alice_device_keys = None
alice_otks = None

def alice_callback(url, **kwargs):
    nonlocal alice_device_keys, alice_otks
    alice_device_keys = kwargs["json"]["device_keys"]
    alice_otks = kwargs["json"]["one_time_keys"]

    return aioresponses.CallbackResult(
        status=200,
        body='{"one_time_key_counts":{"signed_curve25519":100}}',
        headers={
            "Content-Type": "application/json",
        },
    )

mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/upload",
    callback=alice_callback,
)
alice_device_keys_manager = devices.DeviceKeysManager(alice, b"\x00" * 32)
await asyncio.sleep(0.1)

bob_device_keys = None
bob_otks = None

def bob_callback(url, **kwargs):
    nonlocal bob_device_keys, bob_otks
    bob_device_keys = kwargs["json"]["device_keys"]
    bob_otks = kwargs["json"]["one_time_keys"]

    return aioresponses.CallbackResult(
        status=200,
        body='{"one_time_key_counts":{"signed_curve25519":100}}',
        headers={
            "Content-Type": "application/json",
        },
    )

mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/upload",
    callback=bob_callback,
)
bob_device_keys_manager = devices.DeviceKeysManager(bob, b"\x00" * 32)
await asyncio.sleep(0.1)

otk = [
    key
    for id, key in bob_otks.items()
    if id.startswith("signed_curve25519:")
][0]

alice_channel = olm.OlmChannel.create_outbound_channel(
    alice,
    alice_device_keys_manager,
    bob_device_keys,
    otk,
)

alice_msg_encrypted = alice_channel.encrypt(
    "m.room.message", {"body": "Hello world!"}
)

bob_channel, _ = olm.OlmChannel.create_from_encrypted_message(
    bob,
    bob_device_keys_manager,
    "@alice:example.org",
    alice_device_keys_manager.identity_key,
    alice_msg_encrypted,
    partner_device_id="ABCDEFG",
    partner_fingerprint_key=alice_device_keys_manager.fingerprint_key,
)

Next, we reload Bob’s channel from storage and ensure that he can still decrypt a message from Alice and encrypt a message to Alice.

load olm channel test:
bob_channel = olm.OlmChannel.create_from_storage(
    bob,
    bob_device_keys_manager,
    "@alice:example.org",
    alice_device_keys_manager.identity_key,
)

alice_msg_encrypted = alice_channel.encrypt(
    "m.room.message", {"body": "Decrypt after load"}
)

bob_msg_encrypted = bob_channel.encrypt(
    "m.room.message", {"body": "Encrypt after load"}
)

alice_msg_decrypted = bob_channel.decrypt(alice_msg_encrypted)

assert alice_msg_decrypted["content"]["body"] == "Decrypt after load"

bob_msg_decrypted = alice_channel.decrypt(bob_msg_encrypted)

assert bob_msg_decrypted["content"]["body"] == "Encrypt after load"

When we receive an encrypted message, and we do not already have a channel with the sender in memory, we will first try to load the channel from storage, and if none is available, we will try to create a new channel from the encrypted message.
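
The lookup order just described can be sketched as follows. The cache is keyed by (user ID, identity key), matching how create_from_storage looks channels up. find_or_create_channel and its two callables are hypothetical. In practice, load_from_storage would wrap OlmChannel.create_from_storage, and create_from_message would wrap OlmChannel.create_from_encrypted_message, unpacking its (channel, exception) result.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def find_or_create_channel(
    cache: dict[tuple[str, str], T],
    partner_user_id: str,
    partner_identity_key: str,
    load_from_storage: Callable[[], Optional[T]],
    create_from_message: Callable[[], T],
) -> T:
    """Hypothetical lookup: memory first, then storage, then a new channel."""
    key = (partner_user_id, partner_identity_key)
    channel = cache.get(key)
    if channel is None:
        # Not in memory: try persistent storage (e.g. after a client restart)
        channel = load_from_storage()
    if channel is None:
        # Nothing stored either: build a new channel from the message itself
        channel = create_from_message()
    cache[key] = channel
    return channel
```
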

Claiming one-time keys

Part of creating an outbound Olm session is obtaining a one-time key from the other party. Above, we ignored how the key was obtained and simply assumed that we had it. Unfortunately, one-time keys won’t just drop into our laps; we need to request them from the server, using the POST /keys/claim endpoint. This endpoint lets us request keys for multiple devices in a single request, which reduces the number of requests that we need to make. That matters here: since we use Olm to distribute Megolm keys to every recipient device in a room, we may need to create many Olm sessions at once.
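
Because one request can cover many devices, a caller holding a flat list of (user ID, device ID) pairs, such as every device in a room, would first group them by user. The result has the shape that the claim_otks method below takes for its devices argument. group_devices is a hypothetical helper that also drops duplicate pairs.

```python
from collections import defaultdict


def group_devices(pairs: list[tuple[str, str]]) -> dict[str, list[str]]:
    """Hypothetical helper: group (user ID, device ID) pairs by user ID."""
    grouped: defaultdict[str, list[str]] = defaultdict(list)
    for user_id, device_id in pairs:
        # skip duplicates so each device is only claimed once
        if device_id not in grouped[user_id]:
            grouped[user_id].append(device_id)
    return dict(grouped)
```
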

The endpoint also takes a timeout parameter. This indicates how long (in milliseconds) our server will wait for a response from remote servers. When we request a one-time key for a user on a different homeserver, our own homeserver will contact the other user’s homeserver to get the one-time key. The other homeserver may be slow to respond for whatever reason, and we do not want one slow server to hold up the creation of all the other Olm sessions. If no timeout is specified, the default depends on the server, but the Matrix spec recommends a default of 10 seconds. Note that this does not bound how long our own request to POST /keys/claim takes; it only limits how long our homeserver waits for remote servers.
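
If a client also wants to bound the total time of its own request, it has to do that itself. One option, sketched here under that assumption, is to wrap the call in asyncio.wait_for. with_deadline is a hypothetical helper, and returning None on timeout is an arbitrary design choice. Note the unit mismatch: the endpoint's timeout is in milliseconds, while asyncio.wait_for takes seconds.

```python
import asyncio
from typing import Awaitable, Optional, TypeVar

T = TypeVar("T")


async def with_deadline(aw: Awaitable[T], seconds: float) -> Optional[T]:
    """Hypothetical helper: await `aw`, giving up after `seconds` seconds."""
    try:
        # wait_for cancels the awaitable if the deadline passes
        return await asyncio.wait_for(aw, timeout=seconds)
    except asyncio.TimeoutError:
        return None
```

For example, a caller might write something like `await with_deadline(c.claim_otks("signed_curve25519", devices, timeout=10000), 15)`, so that remote servers get up to 10 seconds and the whole request gets up to 15 seconds.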

We add a function to our Client class to claim one-time keys. The POST /keys/claim endpoint allows you to claim one-time keys with different algorithms for different devices. It is unlikely that this functionality is actually needed, so for simplicity, our function takes a single algorithm and the devices to claim keys for.

We may not be able to get a one-time key for every device that we requested. In such cases, the device will simply be missing from the return value. As mentioned above, if the device belongs to a user on a remote homeserver, that homeserver may fail to respond in time. Alternatively, the device may not have any one-time keys available. The response to the POST /keys/claim request includes a failures property listing the homeservers that failed to respond, so that we can tell why we did not get a one-time key for certain users.
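
Combining the request with the two parts of the result lets a caller work out why each missing device is missing. The sketch below assumes inputs in the shapes that claim_otks uses below. The server name is taken from the part of the user ID after the first colon. explain_missing is hypothetical, and the "no one-time key available" reason is only a best guess, since the response does not say why a responsive server returned no key.

```python
from typing import Any


def explain_missing(
    requested: dict[str, list[str]],
    one_time_keys: dict[str, dict[str, Any]],
    failures: list[str],
) -> dict[tuple[str, str], str]:
    """Hypothetical helper: map each unclaimed (user, device) to a likely reason."""
    failed_servers = set(failures)
    reasons: dict[tuple[str, str], str] = {}
    for user_id, device_ids in requested.items():
        claimed = one_time_keys.get(user_id, {})
        # A user ID looks like "@localpart:server.name"; the server name is
        # everything after the first colon (it may itself contain a port)
        server_name = user_id.split(":", 1)[1] if ":" in user_id else ""
        for device_id in device_ids:
            if device_id in claimed:
                continue
            if server_name in failed_servers:
                reasons[(user_id, device_id)] = "server did not respond"
            else:
                reasons[(user_id, device_id)] = "no one-time key available"
    return reasons
```
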

Client class methods:
async def claim_otks(
    self,
    algorithm: str,
    devices: dict[str, list[str]],
    timeout: typing.Optional[int] = None,
) -> typing.Tuple[
    dict[str, dict[str, typing.Tuple[str, dict[str, typing.Any]]]], list[str]
]:
    """Claim one-time keys for the given devices.

    Arguments:

    ``algorithm``:
      the key algorithm for the one-time keys that we want to request
    ``devices``:
      the devices that we want to request one-time keys for.  This is given as
      a ``dict``, mapping user IDs to a list of device IDs
    ``timeout``:
      how long to wait, in milliseconds, for a response from remote servers

    Returns a tuple consisting of:

    - a dict mapping user IDs to device IDs to (key ID, one-time key) tuples.
      A requested device may be missing, indicating that no one-time key for
      that device could be obtained
    - a list of remote servers that failed to respond in time
    """

    otk_map = {
        user_id: {device_id: algorithm for device_id in device_ids}
        for user_id, device_ids in devices.items()
    }
    body: dict[str, typing.Any] = {"one_time_keys": otk_map}
    if timeout is not None:
        body["timeout"] = timeout
    url = self.url("v3/keys/claim")
    resp = await self.authenticated(self.http_session.post, url, json=body)
    async with resp:
        status, resp_body = await check_response(resp)
        schema.ensure_valid(
            resp_body,
            {
                "one_time_keys": dict[str, dict[str, dict]],
                "failures": schema.Optional(dict),
            },
        )

        one_time_keys = {}
        for user_id, device_otks in resp_body["one_time_keys"].items():
            user_otks = {}
            for device_id, device_otk_map in device_otks.items():
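                # The device_otk_map is a map from key ID to key, and should
                # have exactly one item.  Check this first, since `popitem`
                # raises a KeyError on an empty dict.
                if len(device_otk_map) != 1:
                    raise error.InvalidResponseError()
                # `popitem` removes an item from a dict and returns it as a
                # (key, value) tuple
                keyid, _ = user_otks[device_id] = device_otk_map.popitem()
                if not keyid.startswith(algorithm + ":"):
                    raise error.InvalidResponseError()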
            one_time_keys[user_id] = user_otks

        if "failures" in resp_body:
            # `list` returns a list of the keys in the dict.  The values in the
            # "failures" property are not used
            failures = list(resp_body["failures"])
        else:
            failures = []

        return one_time_keys, failures
Tests

We can write a simple test for this method:

test olm:
@pytest.mark.asyncio
async def test_claim_otks(mock_aioresponse):
    async with client.Client(
        storage={
            "access_token": "anaccesstoken",
            "user_id": "@alice:example.org",
            "device_id": "ABCDEFG",
        },
        callbacks={},
        base_client_url="https://matrix-client.example.org/_matrix/client/",
    ) as c:

        {{test claim otks}}
test claim otks:
def callback1(url, **kwargs):
    assert kwargs["json"] == {
        "one_time_keys": {
            "@bob:bob.example": {
                "ABCDEFG": "signed_curve25519",
                "HIJKLMN": "signed_curve25519",
            },
            "@carol:carol.example": {
                "OPQRSTU": "signed_curve25519",
            },
        }
    }
    return aioresponses.CallbackResult(
        status=200,
        body=json.dumps(
            {
                "one_time_keys": {
                    "@bob:bob.example": {
                        "ABCDEFG": {
                            "signed_curve25519:AAAAAA": {
                                "key": "some+key",
                                "signatures": {
                                    "@bob:bob.example": {
                                        "ed25519:ABCDEFG": "some+signature",
                                    },
                                },
                            },
                        },
                        "HIJKLMN": {
                            "signed_curve25519:AAAAAA": {
                                "key": "some+other+key",
                                "signatures": {
                                    "@bob:bob.example": {
                                        "ed25519:ABCDEFG": "some+other+signature",
                                    },
                                },
                            },
                        },
                    },
                },
                "failures": {
                    "carol.example": "ignored",
                },
            }
        ),
        headers={
            "Content-Type": "application/json",
        },
    )

mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/claim",
    callback=callback1,
)

assert await c.claim_otks(
    "signed_curve25519",
    {
        "@bob:bob.example": ["ABCDEFG", "HIJKLMN"],
        "@carol:carol.example": ["OPQRSTU"],
    },
) == (
    {
        "@bob:bob.example": {
            "ABCDEFG": (
                "signed_curve25519:AAAAAA",
                {
                    "key": "some+key",
                    "signatures": {
                        "@bob:bob.example": {
                            "ed25519:ABCDEFG": "some+signature",
                        },
                    },
                },
            ),
            "HIJKLMN": (
                "signed_curve25519:AAAAAA",
                {
                    "key": "some+other+key",
                    "signatures": {
                        "@bob:bob.example": {
                            "ed25519:ABCDEFG": "some+other+signature",
                        },
                    },
                },
            ),
        },
    },
    ["carol.example"],
)