Device Tracking

Now that we have uploaded our own keys to the server, we need to get the keys for other devices. When we send a message, we must know the public keys for all the devices that are in the room. We know what users are in the room, because we can track the m.room.member state events. From there, we can call the POST /keys/query endpoint to get users’ devices and their public keys.
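
As a rough sketch of that first step, the following derives the joined users from a list of m.room.member state events and builds a request body for POST /keys/query. The helper names joined_users and keys_query_body are hypothetical and are not part of the library developed here. The events are assumed to be plain dicts in the m.room.member shape, given in the order they were received, and only joined members are considered, which is a simplification.

```python
import typing


def joined_users(member_events: typing.Iterable[dict]) -> set[str]:
    """Return the user IDs whose most recent membership is ``join``.

    Hypothetical helper: assumes events arrive in order, so a later event
    for a user overrides an earlier one.
    """
    membership: dict[str, str] = {}
    for event in member_events:
        if event.get("type") != "m.room.member":
            continue
        # For membership events, the state key is the affected user's ID
        membership[event["state_key"]] = event["content"]["membership"]
    return {user for user, state in membership.items() if state == "join"}


def keys_query_body(users: typing.Iterable[str]) -> dict:
    """Build a request body asking for all devices of each user."""
    # An empty list means "all of this user's devices"
    return {"device_keys": {user: [] for user in sorted(users)}}


def member(user: str, membership: str) -> dict:
    """Build a minimal membership event for the example below."""
    return {
        "type": "m.room.member",
        "state_key": user,
        "content": {"membership": membership},
    }


events = [
    member("@alice:example.org", "join"),
    member("@bob:example.org", "join"),
    member("@carol:example.org", "join"),
    member("@carol:example.org", "leave"),  # Carol has since left
    member("@dave:example.org", "invite"),  # Dave has not joined yet
]

body = keys_query_body(joined_users(events))
```

Here body requests the devices of Alice and Bob only, since Carol has left and Dave has only been invited.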

However, we need to ensure that every time we send a message, our list of device keys stays up to date, so that new devices will be able to decrypt the message, and we won’t try to send keys to old devices. It would be inefficient to have to query the device keys of every user in the room, every time we send a message. Instead, the server will tell us, via the GET /sync response, when a user’s devices have changed. We can then keep track of the users who have device changes and only need to query the users who have had changes.

The server will notify us when the devices change for a user that we share an encrypted room with. This is done through the device_lists property of the GET /sync response. It is an object with two properties:

  • changed, which lists users who have updated their device keys (including adding or removing devices), or who now share an encrypted room with the user since the last call to GET /sync; and

  • left, which lists the users who no longer share an encrypted room with the user since the last call to GET /sync.

sync schema:
"device_lists": schema.Optional(
    {
        "changed": schema.Optional(list[str]),
        "left": schema.Optional(list[str]),
    }
),

When we encounter this in the sync, we will publish a message that contains the contents of the device_lists property.

process sync response:
if "device_lists" in body:
    await self.publisher.publish(
        DeviceChanges(
            body["device_lists"].get("changed", []),
            body["device_lists"].get("left", []),
        )
    )
client module classes:
class DeviceChanges(typing.NamedTuple):
    """A message indicating that user's devices have changed"""

    changed: list[str]
    left: list[str]


DeviceChanges.changed.__doc__ = "Users whose devices have changed"
DeviceChanges.left.__doc__ = "Users who no longer share a room"

We will now create a DeviceTracker class that we can query for users’ device keys. It will use the client’s storage to cache devices that we’ve previously queried, and it will subscribe to the DeviceChanges messages to know which users it needs to re-fetch devices for.

For this class, we will start with a simple version, and then extend it to handle some edge cases, support querying by public keys, and handle cross-signing, which we will explain later on. Thus the code chunks presented in this section may include references to chunks defined in other sections. These can be safely ignored for now.

devices module classes:
class DeviceTracker:
    """Tracks user devices to ensure that we're up to date"""

    {{DeviceTracker class methods}}
DeviceTracker class methods:
def __init__(self, c: client.Client):
    """
    Arguments:

    ``c``:
      the client object
    """
    self.client = c
    c.subscribe(client.DeviceChanges, self._subscriber)

    {{DeviceTracker initialization}}

We will also create a lock (also called a mutex) to ensure that concurrent calls do not conflict in accessing the shared data.

DeviceTracker initialization:
self.lock = asyncio.Lock()

We will keep track of the users that we are getting updates for (we add users when they show up in the changed list, and remove them when they show up in the left list). This will help us determine whether to cache results: if we are not getting updates for a user, then we should not cache the results for that user. Since we won’t be notified when their devices change, we should err on the safe side and re-fetch their keys the next time we need them. As well, when a user is listed in changed or left, we will clear their cached data so that we will re-fetch their device keys the next time we need them.

DeviceTracker class methods:
async def _subscriber(self, changes: client.DeviceChanges) -> None:
    async with self.lock:
        tracked_users = self.client.storage.get(
            "device_tracker.tracked_users",
            {self.client.user_id: True},  # We always track our own devices
        )

        for user in changes.changed:
            tracked_users[user] = True
            self._delete_user_device_keys(user)
        for user in changes.left:
            if user in tracked_users:
                del tracked_users[user]
            self._delete_user_device_keys(user)

        self.client.storage["device_tracker.tracked_users"] = tracked_users
DeviceTracker class methods:
def _delete_user_device_keys(self, user):
    user_key = f"device_tracker.cache.{user}"
    if user_key in self.client.storage:
        del self.client.storage[user_key]

    {{mark in-flight device key requests}}

(As mentioned above, the “mark in-flight device key requests” chunks are for dealing with some edge cases and will be discussed below. They can be ignored for now.)

Now we create a function to get the device keys for users. It will take an argument to force re-downloading all the keys, ignoring the cache. It will also take a timeout parameter which will limit the time that the homeserver will wait for a response from remote servers.

DeviceTracker class methods:
async def get_device_keys(
    self,
    users: typing.Iterable[str],
    force_download=False,
    timeout: typing.Optional[int] = None,
) -> dict[str, dict[str, dict]]:
    """Get the device keys for the given users.

    Arguments:

    ``users``:
      the user IDs to fetch device keys for
    ``force_download``:
      whether to ignore the cache and force downloading of all device keys from
      the server
    ``timeout``:
      a timeout in milliseconds for the homeserver to wait for responses from
      remote homeservers

    Returns a dict mapping the user IDs to a dict with the following keys:

    ``device_keys``:
      a dict mapping the user's device IDs to device keys

    FIXME: add cross-signing keys
    """
    ret = {}
    users_needing_download: dict[str, list] = {}

    async with self.lock:
        {{load devices from cache}}

        {{check in-flight device key requests}}

    if users_needing_download != {}:
        {{download and cache device keys}}

    {{get in-flight device key results}}

    return ret

Todo

validate device keys

  • make sure user_id and device_id properties match

  • if signed, make sure signature matches

  • make sure ed25519 key doesn’t change

(Again, the two chunks referring to “in-flight device key requests” are for handling some edge cases described below.)

If the argument to force downloading is set to False, we will try to obtain the device keys from the cache and keep track of the users that the cache doesn’t have keys for. Our users_needing_download variable, which keeps track of the users that we need to download keys for, will be formatted in the way that POST /keys/query expects, to save us from converting it later on: it will be an object mapping from the user ID to an array. The array lists the devices that we want the keys for, or is empty if we want all of a user’s device keys. Since the latter is what we want, we set all values to the empty array.

load devices from cache:
if not force_download:
    for user in users:
        storage_key = f"device_tracker.cache.{user}"
        if storage_key in self.client.storage:
            ret[user] = self.client.storage[storage_key]
        else:
            users_needing_download[user] = []
else:
    users_needing_download = {user: [] for user in users}

We then download the device keys from the server, update our cache for the users that we are tracking, and then return the keys. The response from the server will be an object with several possible properties. The properties that we will be interested in here are:

  • device_keys is an object mapping from user ID to an object mapping device ID to device key.

  • failures is an object whose keys are the names of remote servers that could not be contacted to provide their users’ keys. The values in the object do not matter. If a server is included in failures, no device information will be given for that server’s users. In this case, our function will return a result indicating that the users do not have any devices, but will not cache this result so that it will retry later.

The other properties in the response are related to cross-signing, which will be discussed later on.
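
To make the role of failures concrete, here is a small stand-alone sketch that separates queried users into those whose server answered and those whose server is listed in failures. The function name split_by_failure and the sample response are made up for illustration. Unlike the tracker, the sketch does not treat malformed user IDs specially.

```python
def split_by_failure(users: list[str], response: dict) -> tuple[list[str], list[str]]:
    """Split ``users`` into (answered, failed) based on ``failures``."""
    failures = response.get("failures", {})
    answered, failed = [], []
    for user in users:
        # The server name is everything after the first colon in the user ID
        _, _, server = user.partition(":")
        if server in failures:
            failed.append(user)
        else:
            answered.append(user)
    return answered, failed


users = ["@bob:example.org", "@erin:remote.example.com"]
response = {
    "device_keys": {
        "@bob:example.org": {
            "HIJKLMN": {"device_id": "HIJKLMN", "user_id": "@bob:example.org"},
        },
    },
    # remote.example.com could not be reached; the value does not matter
    "failures": {"remote.example.com": {}},
}

answered, failed = split_by_failure(users, response)

# Every queried user gets a result, but users on failed servers get an
# empty device list, and that empty result should not be cached
results = {
    user: {"device_keys": response.get("device_keys", {}).get(user, {})}
    for user in users
}
```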

Todo

also need to maintain map of algorithm -> pubkey -> (user, device)

download and cache device keys:
req_body: dict[str, typing.Any] = {"device_keys": users_needing_download}
if timeout is not None:
    req_body["timeout"] = timeout
try:
    resp = await self.client.authenticated(
        self.client.http_session.post,
        self.client.url("v3/keys/query"),
        json=req_body,
    )
    status, resp_body = await client.check_response(resp)
    schema.ensure_valid(
        resp_body,
        {
            "device_keys": schema.Optional(dict[str, dict[str, dict]]),
            "failures": schema.Optional(dict),
            # FIXME: cross-signing keys
        },
    )
except:
    {{clean up in-flight device key results}}

    raise

user_device_keys = resp_body.get("device_keys", {})
failures = resp_body.get("failures", {})
tracked_users = self.client.storage.get("device_tracker.tracked_users", {})

async with self.lock:
    for user in users_needing_download.keys():
        user_device_info = {}
        device_keys = user_device_keys.get(user, {})
        user_device_info["device_keys"] = device_keys
        # FIXME: process cross-signing keys

        self._cache_result(user, tracked_users, failures, user_device_info)
        ret[user] = user_device_info

        {{set result for in-flight device key request}}
DeviceTracker class methods:
def _cache_result(self, user, tracked_users, failures, user_device_info):
    if self._should_cache_result(user, tracked_users, failures):
        self.client.storage[f"device_tracker.cache.{user}"] = user_device_info

(Again, the two chunks referring to “in-flight device key requests” are for handling some edge cases described below.)

We create a function that tells us whether the results should be cached. As mentioned above, we don’t cache results for users who aren’t being tracked, or for users whose homeserver could not be contacted.

DeviceTracker class methods:
def _should_cache_result(
    self, user: str, tracked_users: dict, failures: dict
) -> bool:
    {{check do not cache flag}}

    if user not in tracked_users:
        return False

    # find the user's server name
    split_user_id = user.split(":", 1)
    if len(split_user_id) != 2:
        # invalid user ID, so don't cache
        return False
    if split_user_id[1] in failures:
        return False

    return True

(The “check do not cache flag” chunk is for handling edge cases described below.)

Tests
tests/test_device_tracking.py:
# {{copyright}}

import asyncio
import aioresponses
import json
import pytest

from matrixlib import client
from matrixlib import devices


{{test device tracking}}

To test the device tracker, we will pre-populate our device cache, and first ensure that when we try to get the device keys, we get the cached values.

test device tracking:
@pytest.mark.asyncio
async def test_basic_device_tracking(mock_aioresponse):
    async with client.Client(
        storage={
            "access_token": "anaccesstoken",
            "user_id": "@alice:example.org",
            "device_id": "ABCDEFG",
            "device_tracker.tracked_users": {
                "@bob:example.org": True,
                "@carol:example.org": True,
            },
            "device_tracker.cache.@bob:example.org": {
                "device_keys": {
                    "HIJKLMN": {
                        "algorithms": [],
                        "device_id": "HIJKLMN",
                        "keys": {
                            "curve25519:HIJKLMN": "some+key",
                        },
                        "user_id": "@bob:example.org",
                    },
                },
            },
            "device_tracker.cache.@carol:example.org": {
                "device_keys": {
                    "OPQRSTU": {
                        "algorithms": [],
                        "device_id": "OPQRSTU",
                        "keys": {
                            "curve25519:OPQRSTU": "some+other+key",
                        },
                        "user_id": "@carol:example.org",
                    },
                },
            },
        },
        callbacks={},
        base_client_url="https://matrix-client.example.org/_matrix/client/",
    ) as c:
        {{basic device tracking test}}
basic device tracking test:
tracker = devices.DeviceTracker(c)

assert await tracker.get_device_keys(
    ["@bob:example.org", "@carol:example.org"]
) == {
    "@bob:example.org": {
        "device_keys": {
            "HIJKLMN": {
                "algorithms": [],
                "device_id": "HIJKLMN",
                "keys": {
                    "curve25519:HIJKLMN": "some+key",
                },
                "user_id": "@bob:example.org",
            },
        },
    },
    "@carol:example.org": {
        "device_keys": {
            "OPQRSTU": {
                "algorithms": [],
                "device_id": "OPQRSTU",
                "keys": {
                    "curve25519:OPQRSTU": "some+other+key",
                },
                "user_id": "@carol:example.org",
            },
        },
    },
}

We then create a sync response that will indicate that Bob’s devices have been updated.

basic device tracking test:
mock_aioresponse.get(
    "https://matrix-client.example.org/_matrix/client/v3/sync?timeout=30000",
    status=200,
    body=json.dumps(
        {
            "device_lists": {
                "changed": ["@bob:example.org"],
            },
            "next_batch": "token1",
        }
    ),
    headers={
        "content-type": "application/json",
    },
)
mock_aioresponse.get(
    "https://matrix-client.example.org/_matrix/client/v3/sync?since=token1&timeout=30000",
    status=200,
    body='{"next_batch":"token1"}',
    headers={
        "content-type": "application/json",
    },
    repeat=True,
)

def subscriber(msg) -> None:
    c.stop_sync()

c.subscribe(client.DeviceChanges, subscriber)
c.subscribe(client.SyncFailed, subscriber)

c.start_sync()

try:
    await c.sync_task
except asyncio.CancelledError:
    pass

We now ensure that when we try to get the device keys, it refreshes Bob’s keys and returns the new keys. We use a callback in our HTTP request mock handler to ensure that the request body only requests Bob’s keys.

basic device tracking test:
def callback(url, **kwargs):
    assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}

    return aioresponses.CallbackResult(
        status=200,
        body=json.dumps(
            {
                "device_keys": {
                    "@bob:example.org": {
                        "VWXYZAB": {
                            "algorithms": [],
                            "device_id": "VWXYZAB",
                            "keys": {
                                "curve25519:HIJKLMN": "some+new+key",
                            },
                            "user_id": "@bob:example.org",
                        },
                    },
                },
            }
        ),
        headers={
            "Content-Type": "application/json",
        },
    )

mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/query",
    callback=callback,
)

assert await tracker.get_device_keys(
    ["@bob:example.org", "@carol:example.org"]
) == {
    "@bob:example.org": {
        "device_keys": {
            "VWXYZAB": {
                "algorithms": [],
                "device_id": "VWXYZAB",
                "keys": {
                    "curve25519:HIJKLMN": "some+new+key",
                },
                "user_id": "@bob:example.org",
            },
        },
    },
    "@carol:example.org": {
        "device_keys": {
            "OPQRSTU": {
                "algorithms": [],
                "device_id": "OPQRSTU",
                "keys": {
                    "curve25519:OPQRSTU": "some+other+key",
                },
                "user_id": "@carol:example.org",
            },
        },
    },
}

Edge cases

There are a couple of edge cases that we need to deal with. First of all, if we are querying the server for the device keys for a user, and concurrently, another call to get_device_keys needs to fetch the device keys for the same user, we don’t need to query the server for that user again. Instead, we can just use the result that we get from the first query. While we could just make a second request and not worry about efficiency, this raises the question of how this will affect our caching – if the two results differ, which result should be cached? This would also complicate handling for the second edge case. And, as it turns out, the record-keeping that we do to handle this edge case will also help with handling the second edge case. So while on the surface it may seem that it is simpler to just make a second request rather than synchronizing between calls to our function, it actually turns out to be simpler to do the synchronization.

The second edge case is that if we are querying the server for the device keys for a user, and concurrently, we get a sync response saying that the user’s devices have changed, we do not know if the result of our server query represents the user’s old devices or their new devices. In this case, we could retry the request. But what do we do if we get another sync response saying that the user’s devices have changed while the second request is in-flight? How many times will we keep retrying? If we don’t limit the number of retries, this could loop infinitely. Instead of retrying, for the sake of simplicity, in our implementation we will simply return the first result, but we will not cache the value so that the next time we request the user’s devices, we will re-query. Even though we may not be obtaining the most up-to-date value, we are still providing a correct response for the time that the request was made. Other implementations could use some sort of retry mechanism, but should take care to ensure that it will return after a reasonable amount of time, even if the server continuously indicates that devices have changed.
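
For comparison, a retry-based approach might look like the sketch below. This is not what our implementation does. The callables fetch and changed_during_fetch are hypothetical stand-ins for the server query and for a check of whether a device change arrived while the query was in flight. The max_attempts bound is what guarantees that the function returns.

```python
import asyncio
import typing


async def fetch_with_bounded_retry(
    fetch: typing.Callable[[], typing.Awaitable[dict]],
    changed_during_fetch: typing.Callable[[], bool],
    max_attempts: int = 3,
) -> tuple[dict, bool]:
    """Query until no change arrives mid-request, or until we give up.

    Returns the last result, and whether it is known to be up to date (and
    therefore safe to cache).
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    result: dict = {}
    for _ in range(max_attempts):
        result = await fetch()
        if not changed_during_fetch():
            return result, True
    # We gave up: the result was valid when requested, but may be stale
    return result, False


async def _demo() -> tuple[dict, bool]:
    calls = 0

    async def fetch() -> dict:
        nonlocal calls
        calls += 1
        return {"attempt": calls}

    # Pretend that a change arrives during the first request only
    flags = iter([True, False])
    return await fetch_with_bounded_retry(fetch, lambda: next(flags))


demo_result = asyncio.run(_demo())
```

In this run, the first result is discarded because a change arrived while it was being fetched, and the second result is returned as safe to cache.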

To handle the first edge case, when we make a server request to query a user’s device keys, we will record in our DeviceTracker object that we are querying that user’s keys. If we get another call to get_device_keys for that user’s keys, we will note that there is already a request in-flight, and we can wait for the first request to complete, and get the value from there. We will make use of Python’s asyncio.Future class for this, which represents a variable that will have a value in the future. Thus when we query the server for a user’s device keys, we will store a Future, and when we receive the result, we will resolve the Future to the user’s devices. We will store the Futures in the in_flight member variable, which will be a dict mapping user IDs to Futures.
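
The idea of sharing one Future between callers can be seen in isolation in the following sketch. It is a simplified model rather than the tracker itself: get_keys is a hypothetical function, the server round trip is replaced by a short sleep, and no locking or caching is done.

```python
import asyncio


async def main() -> tuple[list, int]:
    loop = asyncio.get_running_loop()
    in_flight: dict[str, asyncio.Future] = {}
    requests_made = 0

    async def get_keys(user: str) -> dict:
        nonlocal requests_made
        if user in in_flight:
            # Another call is already fetching this user: wait for its result
            return await in_flight[user]

        future = loop.create_future()
        in_flight[user] = future
        requests_made += 1
        await asyncio.sleep(0.01)  # stands in for the request to the server
        result = {"device_keys": {}}

        # Resolve the Future so that anything waiting on it wakes up
        future.set_result(result)
        del in_flight[user]
        return result

    results = await asyncio.gather(
        get_keys("@bob:example.org"),
        get_keys("@bob:example.org"),
    )
    return results, requests_made


demo_results, demo_requests = asyncio.run(main())
```

Both calls return the same result object, but only one simulated request is made.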

DeviceTracker initialization:
self.in_flight: dict[str, asyncio.Future] = {}

We now write the code chunks in our get_device_keys function that we mentioned above would be used for dealing with the edge cases.

The first chunk occurs before we make our request to the server. It will look at the users that we were going to request from the server, and see if we have in-flight requests for those users. If so, it will record the Futures for those users, and drop them from our request. If not, it will mark those users as having requests in-flight, since we will be making the request.

check in-flight device key requests:
loop = asyncio.get_running_loop()

device_futures = {}

for user in users_needing_download.keys():
    # If we already have a request in-flight for the user, we can use that
    # result instead of re-requesting.  Otherwise, record that we will be
    # making a request for that user.
    if user in self.in_flight:
        device_futures[user] = self.in_flight[user]
    else:
        self.in_flight[user] = loop.create_future()

for user in device_futures.keys():
    # drop users from our request if we're using the result from a future
    del users_needing_download[user]

The second chunk occurs after we have made our request to the server. If the request raises an exception, we will set that exception on our Futures, so that the exception gets passed to anything that was waiting on our result.

clean up in-flight device key results:
e = typing.cast(BaseException, sys.exc_info()[1])
for user in users_needing_download.keys():
    self.in_flight[user].set_exception(e)
    del self.in_flight[user]
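
The effect of set_exception on a waiting caller can be checked with a short stand-alone sketch. The waiter coroutine and the ConnectionError are chosen purely for illustration. The point is that awaiting a Future re-raises whatever exception was set on it.

```python
import asyncio


async def demo_exception() -> str:
    loop = asyncio.get_running_loop()
    future = loop.create_future()

    async def waiter() -> str:
        try:
            await future
        except ConnectionError as e:
            return f"waiter saw: {e}"
        return "waiter saw a result"

    task = asyncio.create_task(waiter())
    await asyncio.sleep(0)  # let the waiter start waiting on the future

    # The failed request passes its exception on to the waiter
    future.set_exception(ConnectionError("server unreachable"))
    return await task


outcome = asyncio.run(demo_exception())
```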

The third chunk sets the result of our Futures after a successful request. So whether the request fails or succeeds, the Future will resolve.

set result for in-flight device key request:
self.in_flight[user].set_result(user_device_info)
del self.in_flight[user]

And the last chunk gets the result for any Futures that we are waiting on. Note that we get the result for the Futures after we make our own request to the server. That is because we don’t want to wait for the Futures to resolve before we make our request; the requests can be made concurrently.

get in-flight device key results:
for user, future in device_futures.items():
    ret[user] = await future
Tests
test device tracking:
@pytest.mark.asyncio
async def test_concurrent_device_requests(mock_aioresponse):
    async with client.Client(
        storage={
            "access_token": "anaccesstoken",
            "user_id": "@alice:example.org",
            "device_id": "ABCDEFG",
            "device_tracker.tracked_users": {
                "@bob:example.org": True,
            },
        },
        callbacks={},
        base_client_url="https://matrix-client.example.org/_matrix/client/",
    ) as c:
        {{concurrent device requests test}}

To test that we can make concurrent requests, we will first make a request for a user’s device keys. We will process the request for the user’s keys using a custom handler, and make another request for that user’s device keys within that handler, to simulate a request being made concurrently. We will ensure that the device tracker does not make another request for the user’s device keys.

concurrent device requests test:
tracker = devices.DeviceTracker(c)

second_request_task = None

async def callback(url, **kwargs):
    assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}

    nonlocal second_request_task
    second_request_task = asyncio.create_task(
        tracker.get_device_keys(["@bob:example.org"])
    )

    # give the second request some time to execute to make sure that it blocks
    # on the future
    await asyncio.sleep(0.2)

    return aioresponses.CallbackResult(
        status=200,
        body=json.dumps(
            {
                "device_keys": {
                    "@bob:example.org": {
                        "HIJKLMN": {
                            "algorithms": [],
                            "device_id": "HIJKLMN",
                            "keys": {
                                "curve25519:HIJKLMN": "some+key",
                            },
                            "user_id": "@bob:example.org",
                        },
                    },
                },
            }
        ),
        headers={
            "Content-Type": "application/json",
        },
    )

mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/query",
    callback=callback,
)

assert await tracker.get_device_keys(["@bob:example.org"]) == {
    "@bob:example.org": {
        "device_keys": {
            "HIJKLMN": {
                "algorithms": [],
                "device_id": "HIJKLMN",
                "keys": {
                    "curve25519:HIJKLMN": "some+key",
                },
                "user_id": "@bob:example.org",
            },
        },
    },
}

assert await second_request_task == {
    "@bob:example.org": {
        "device_keys": {
            "HIJKLMN": {
                "algorithms": [],
                "device_id": "HIJKLMN",
                "keys": {
                    "curve25519:HIJKLMN": "some+key",
                },
                "user_id": "@bob:example.org",
            },
        },
    },
}

Now we will handle the second edge case, where a sync comes in while we’re in the process of querying the server for a user’s devices. As explained above, we will return the result that we obtain from the server, but we will not cache the result, since it could be outdated.

When we receive the DeviceChanges message from the sync loop, we will check whether the users that are in changed or left have an in-flight request, and if so, we will mark those users indicating that their results should not be cached.

DeviceTracker initialization:
self.do_not_cache: typing.Set[str] = set()
mark in-flight device key requests:
if user in self.in_flight:
    self.do_not_cache.add(user)

Then, in our _should_cache_result function, we check whether the user is marked for not caching. If so, we clear the flag and return False, indicating that the result should not be cached. Since we know that there is no other request for that user, it is safe to clear the flag.

check do not cache flag:
if user in self.do_not_cache:
    self.do_not_cache.remove(user)
    return False
Tests
test device tracking:
@pytest.mark.asyncio
async def test_update_during_device_request(mock_aioresponse):
    async with client.Client(
        storage={
            "access_token": "anaccesstoken",
            "user_id": "@alice:example.org",
            "device_id": "ABCDEFG",
            "device_tracker.tracked_users": {
                "@bob:example.org": True,
            },
        },
        callbacks={},
        base_client_url="https://matrix-client.example.org/_matrix/client/",
    ) as c:
        {{update during device request test}}

To test that we handle a device update that arrives during a request, we will first make a request for a user’s device keys. Again, we will process the request for the user’s keys using a custom handler, but this time in our handler, we will publish a DeviceChanges message to the tracker, simulating what would happen if a sync response came in. We then ensure that when we make another request, it will re-query the server rather than use a cached result.

update during device request test:
tracker = devices.DeviceTracker(c)

async def callback1(url, **kwargs):
    assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}

    await c.publisher.publish(client.DeviceChanges(["@bob:example.org"], []))

    return aioresponses.CallbackResult(
        status=200,
        body=json.dumps(
            {
                "device_keys": {
                    "@bob:example.org": {
                        "HIJKLMN": {
                            "algorithms": [],
                            "device_id": "HIJKLMN",
                            "keys": {
                                "curve25519:HIJKLMN": "some+key",
                            },
                            "user_id": "@bob:example.org",
                        },
                    },
                },
            }
        ),
        headers={
            "Content-Type": "application/json",
        },
    )

mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/query",
    callback=callback1,
)

def callback2(url, **kwargs):
    assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}

    return aioresponses.CallbackResult(
        status=200,
        body=json.dumps(
            {
                "device_keys": {
                    "@bob:example.org": {
                        "VWXYZAB": {
                            "algorithms": [],
                            "device_id": "VWXYZAB",
                            "keys": {
                                "curve25519:HIJKLMN": "some+new+key",
                            },
                            "user_id": "@bob:example.org",
                        },
                    },
                },
            }
        ),
        headers={
            "Content-Type": "application/json",
        },
    )

mock_aioresponse.post(
    "https://matrix-client.example.org/_matrix/client/v3/keys/query",
    callback=callback2,
)

assert await tracker.get_device_keys(["@bob:example.org"]) == {
    "@bob:example.org": {
        "device_keys": {
            "HIJKLMN": {
                "algorithms": [],
                "device_id": "HIJKLMN",
                "keys": {
                    "curve25519:HIJKLMN": "some+key",
                },
                "user_id": "@bob:example.org",
            },
        },
    },
}

assert await tracker.get_device_keys(["@bob:example.org"]) == {
    "@bob:example.org": {
        "device_keys": {
            "VWXYZAB": {
                "algorithms": [],
                "device_id": "VWXYZAB",
                "keys": {
                    "curve25519:HIJKLMN": "some+new+key",
                },
                "user_id": "@bob:example.org",
            },
        },
    },
}