Device Tracking
Now that we have uploaded our own keys to the server, we need to get the keys
for other devices. When we send a message, we must know the public keys for
all the devices that are in the room. We know what users are in the room,
because we can track the m.room.member
state
events. From there, we can call the POST /keys/query
endpoint to get users’ devices and their public keys.
However, we need to ensure that every time we send a message, our list of
device keys stays up to date, so that new devices will be able to decrypt the
message, and we won’t try to send keys to old devices. It would be inefficient
to have to query the device keys of every user in the room, every time we send
a message. Instead, the server will tell us, via the GET /sync
response, when a user’s devices have changed. We can then keep track
of the users who have device changes and only need to query the users who have
had changes.
The server will notify us when the devices change for a user that we share an
encrypted room with. This is done through the device_lists
property of the GET /sync
response. It is an object with two properties:
changed, which lists users who have updated their device keys (including adding or removing devices), or who now share an encrypted room with the user since the last call to GET /sync; and
left, which lists the users who no longer share an encrypted room with the user since the last call to GET /sync.
"device_lists": schema.Optional(
    {
        "changed": schema.Optional(list[str]),
        "left": schema.Optional(list[str]),
    }
),
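As a rough illustration, the fragment below shows what the relevant part of a sync response might look like and how the two lists can be read. The user ID, the sample body, and the `read_device_lists` helper are made up for this example and are not part of the library; since both properties are optional, the helper falls back to empty lists.

```python
import json

# A hypothetical /sync response body; only the fields relevant here are shown.
SAMPLE_SYNC_BODY = json.loads(
    """
    {
        "next_batch": "token1",
        "device_lists": {
            "changed": ["@bob:example.org"]
        }
    }
    """
)


def read_device_lists(body: dict) -> tuple[list[str], list[str]]:
    """Return the (changed, left) user lists from a sync response body."""
    device_lists = body.get("device_lists", {})
    return device_lists.get("changed", []), device_lists.get("left", [])


changed, left = read_device_lists(SAMPLE_SYNC_BODY)
```

Here `changed` holds Bob's user ID and `left` is empty, because the sample body omits the `left` property.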
When we encounter this in the sync, we will publish a message that contains the
contents of the device_lists
property.
if "device_lists" in body:
    await self.publisher.publish(
        DeviceChanges(
            body["device_lists"].get("changed", []),
            body["device_lists"].get("left", []),
        )
    )

class DeviceChanges(typing.NamedTuple):
    """A message indicating that a user's devices have changed"""

    changed: list[str]
    left: list[str]


DeviceChanges.changed.__doc__ = "Users whose devices have changed"
DeviceChanges.left.__doc__ = "Users who no longer share a room"
We will now create a DeviceTracker
class that we can query for users’ device
keys. It will use the client’s storage to cache devices that we’ve previously
queried, and it will subscribe to the DeviceChanges
messages to know what
users it needs to re-fetch devices for.
For this class, we will start with a simple version, and then extend it to handle some edge cases, support querying by public keys, and handle cross-signing, which we will explain later on. Thus the code chunks presented in this section may include references to chunks defined in other sections. These can be safely ignored for now.
class DeviceTracker:
    """Tracks user devices to ensure that we're up to date"""

    {{DeviceTracker class methods}}

def __init__(self, c: client.Client):
    """
    Arguments:

    ``c``:
      the client object
    """
    self.client = c
    c.subscribe(client.DeviceChanges, self._subscriber)
    {{DeviceTracker initialization}}
We will also create a lock (also called a mutex) to ensure that concurrent calls do not conflict in accessing the shared data.
self.lock = asyncio.Lock()
We will keep track of the users that we are getting updates for (we add users
when they show up in the changed list, and remove them when they show up in
the left list). This will help us determine whether to cache results: if we
are not getting updates for a user, then we should not cache the results for
that user. Since we won’t be notified when their devices change, we should err
on the safe side and re-fetch their keys the next time we need them. As well,
when a user is listed in changed or left, we will clear their cached data
so that we will re-fetch their device keys the next time we need them.
async def _subscriber(self, changes: client.DeviceChanges) -> None:
    async with self.lock:
        tracked_users = self.client.storage.get(
            "device_tracker.tracked_users",
            {self.client.user_id: True},  # We always track our own devices
        )
        for user in changes.changed:
            tracked_users[user] = True
            self._delete_user_device_keys(user)
        for user in changes.left:
            if user in tracked_users:
                del tracked_users[user]
            self._delete_user_device_keys(user)
        self.client.storage["device_tracker.tracked_users"] = tracked_users

def _delete_user_device_keys(self, user):
    user_key = f"device_tracker.cache.{user}"
    if user_key in self.client.storage:
        del self.client.storage[user_key]
    {{mark in-flight device key requests}}
(As mentioned above, the “mark in-flight device key requests” chunks are for dealing with some edge cases and will be discussed below. They can be ignored for now.)
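The bookkeeping can also be tried out in isolation. The following is a minimal sketch, assuming that storage behaves like a plain dict; the `apply_device_changes` function and the dict-based storage are stand-ins invented for this example, and the locking and in-flight handling are left out.

```python
def apply_device_changes(
    storage: dict, own_user_id: str, changed: list[str], left: list[str]
) -> None:
    """Update the tracked-user set and drop cached keys for affected users."""
    tracked = storage.get("device_tracker.tracked_users", {own_user_id: True})
    for user in changed:
        tracked[user] = True
        storage.pop(f"device_tracker.cache.{user}", None)
    for user in left:
        tracked.pop(user, None)
        storage.pop(f"device_tracker.cache.{user}", None)
    storage["device_tracker.tracked_users"] = tracked


# Hypothetical starting state: Carol is tracked and has a cached entry.
storage = {
    "device_tracker.tracked_users": {
        "@alice:example.org": True,
        "@carol:example.org": True,
    },
    "device_tracker.cache.@carol:example.org": {"device_keys": {}},
}
apply_device_changes(
    storage, "@alice:example.org", ["@bob:example.org"], ["@carol:example.org"]
)
```

After the call, Bob is tracked, Carol is no longer tracked, and Carol's cached entry is gone, so her keys would be fetched again the next time they are needed.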
Now we create a function to get the device keys for users. It will take an argument to force re-downloading all the keys, ignoring the cache. It will also take a timeout parameter which will limit the time that the homeserver will wait for a response from remote servers.
async def get_device_keys(
    self,
    users: typing.Iterable[str],
    force_download=False,
    timeout: typing.Optional[int] = None,
) -> dict[str, dict[str, dict]]:
    """Get the device keys for the given users.

    Arguments:

    ``users``:
      the user IDs to fetch device keys for
    ``force_download``:
      whether to ignore the cache and force downloading of all device keys from
      the server
    ``timeout``:
      a timeout in milliseconds for the homeserver to wait for responses from
      remote homeservers

    Returns a dict mapping the user IDs to a dict with the following keys:

    ``device_keys``:
      a dict mapping the user's device IDs to device keys

    FIXME: add cross-signing keys
    """
    ret = {}
    users_needing_download: dict[str, list] = {}
    async with self.lock:
        {{load devices from cache}}
        {{check in-flight device key requests}}
    if users_needing_download:
        {{download and cache device keys}}
    {{get in-flight device key results}}
    return ret
Todo
validate device keys
make sure user_id and device_id properties match
if signed, make sure signature matches
make sure ed25519 key doesn’t change
(Again, the two chunks referring to “in-flight device key requests” are for handling some edge cases described below.)
If the argument to force downloading is set to False, we will try to obtain
the device keys from the cache and keep track of the users that it doesn’t have
the keys for. Our users_needing_download variable, which keeps track of the
users that we need to download keys for, will be formatted in the way that
POST /keys/query expects, to save us from converting it later on: it will be an
object mapping from the user ID to an array. The array is a list indicating
the devices that we want the keys for, or an empty array if we want all of a
user’s device keys. Since the latter is what we want, we set all values to the
empty array.
if not force_download:
    for user in users:
        storage_key = f"device_tracker.cache.{user}"
        if storage_key in self.client.storage:
            ret[user] = self.client.storage[storage_key]
        else:
            users_needing_download[user] = []
else:
    users_needing_download = {user: [] for user in users}
We then download the device keys from the server, update our cache for the users that we are tracking, and then return the keys. The response from the server will be an object with several possible properties. The properties that we will be interested in here are:
device_keys is an object mapping from user ID to an object mapping device ID to device key.
failures is an object in which the keys are the names of remote servers that could not be contacted to provide their users’ keys. The values in the object do not matter. If a server is included in failures, its users will not have any device information given. In this case, our function will return a result indicating that the users do not have any devices, but will not cache this result so that it will retry later.
The other properties in the response are related to cross-signing, which will be discussed later on.
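To make the shape of the response concrete, here is a small sketch with a made-up response body in which one remote server could not be reached. The server name `unreachable.example.com`, the sample data, and the `devices_by_user` helper are assumptions for illustration only.

```python
# Hypothetical response body: Bob's server answered, Carol's could not be reached.
SAMPLE_RESPONSE = {
    "device_keys": {
        "@bob:example.org": {
            "HIJKLMN": {
                "algorithms": [],
                "device_id": "HIJKLMN",
                "keys": {"curve25519:HIJKLMN": "some+key"},
                "user_id": "@bob:example.org",
            },
        },
    },
    "failures": {"unreachable.example.com": {}},
}


def devices_by_user(requested: list[str], resp_body: dict) -> dict[str, dict]:
    """Map every requested user to their devices, defaulting to no devices."""
    device_keys = resp_body.get("device_keys", {})
    return {user: device_keys.get(user, {}) for user in requested}


result = devices_by_user(
    ["@bob:example.org", "@carol:unreachable.example.com"], SAMPLE_RESPONSE
)
```

Carol ends up with an empty device dict. Taken on its own, that looks the same as a user with no devices, which is why the `failures` property has to be consulted before deciding whether the result may be cached.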
Todo
also need to maintain map of algorithm -> pubkey -> (user, device)
req_body: dict[str, typing.Any] = {"device_keys": users_needing_download}
if timeout is not None:
    req_body["timeout"] = timeout

try:
    resp = await self.client.authenticated(
        self.client.http_session.post,
        self.client.url("v3/keys/query"),
        json=req_body,
    )
    status, resp_body = await client.check_response(resp)
    schema.ensure_valid(
        resp_body,
        {
            "device_keys": schema.Optional(dict[str, dict[str, dict]]),
            "failures": schema.Optional(dict),
            # FIXME: cross-signing keys
        },
    )
except:
    {{clean up in-flight device key results}}
    raise

user_device_keys = resp_body.get("device_keys", {})
failures = resp_body.get("failures", {})
tracked_users = self.client.storage.get("device_tracker.tracked_users", {})

async with self.lock:
    for user in users_needing_download.keys():
        user_device_info = {}
        device_keys = user_device_keys.get(user, {})
        user_device_info["device_keys"] = device_keys
        # FIXME: process cross-signing keys
        self._cache_result(user, tracked_users, failures, user_device_info)
        ret[user] = user_device_info
        {{set result for in-flight device key request}}

def _cache_result(self, user, tracked_users, failures, user_device_info):
    if self._should_cache_result(user, tracked_users, failures):
        self.client.storage[f"device_tracker.cache.{user}"] = user_device_info
(Again, the two chunks referring to “in-flight device key requests” are for handling some edge cases described below.)
We create a function that tells us whether the results should be cached. As mentioned above, we don’t cache results from users who aren’t being tracked, or if the user’s homeserver could not be contacted.
def _should_cache_result(
    self, user: str, tracked_users: dict, failures: dict
) -> bool:
    {{check do not cache flag}}
    if user not in tracked_users:
        return False
    # find the user's server name
    split_user_id = user.split(":", 1)
    if len(split_user_id) != 2:
        # invalid user ID, so don't cache
        return False
    if split_user_id[1] in failures:
        return False
    return True
(The “check do not cache flag” chunk is for handling edge cases described below.)
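One detail worth spelling out is why the user ID is split with a limit of one. A server name may itself contain colons, for example when it includes a port, so only the first colon separates the localpart from the server name. The `server_name` helper below is a hypothetical standalone version of that step, written only to show the behaviour.

```python
from typing import Optional


def server_name(user_id: str) -> Optional[str]:
    """Return the server name part of a user ID, or None if there is none."""
    parts = user_id.split(":", 1)  # split only at the first colon
    return parts[1] if len(parts) == 2 else None


plain = server_name("@bob:example.org")
with_port = server_name("@bob:example.org:8448")
invalid = server_name("bob")
```

With an unlimited split, the second ID would yield three parts and the port would be lost from the server name, so the lookup in `failures` could miss.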
Tests
tests/test_device_tracking.py:
# {{copyright}}
import asyncio
import aioresponses
import json
import pytest
from matrixlib import client
from matrixlib import devices
{{test device tracking}}
To test the device tracker, we will pre-populate our device cache, and first ensure that when we try to get the device keys, we get the cached values.
@pytest.mark.asyncio
async def test_basic_device_tracking(mock_aioresponse):
async with client.Client(
storage={
"access_token": "anaccesstoken",
"user_id": "@alice:example.org",
"device_id": "ABCDEFG",
"device_tracker.tracked_users": {
"@bob:example.org": True,
"@carol:example.org": True,
},
"device_tracker.cache.@bob:example.org": {
"device_keys": {
"HIJKLMN": {
"algorithms": [],
"device_id": "HIJKLMN",
"keys": {
"curve25519:HIJKLMN": "some+key",
},
"user_id": "@bob:example.org",
},
},
},
"device_tracker.cache.@carol:example.org": {
"device_keys": {
"OPQRSTU": {
"algorithms": [],
"device_id": "OPQRSTU",
"keys": {
"curve25519:OPQRSTU": "some+other+key",
},
"user_id": "@carol:example.org",
},
},
},
},
callbacks={},
base_client_url="https://matrix-client.example.org/_matrix/client/",
) as c:
{{basic device tracking test}}
tracker = devices.DeviceTracker(c)
assert await tracker.get_device_keys(
["@bob:example.org", "@carol:example.org"]
) == {
"@bob:example.org": {
"device_keys": {
"HIJKLMN": {
"algorithms": [],
"device_id": "HIJKLMN",
"keys": {
"curve25519:HIJKLMN": "some+key",
},
"user_id": "@bob:example.org",
},
},
},
"@carol:example.org": {
"device_keys": {
"OPQRSTU": {
"algorithms": [],
"device_id": "OPQRSTU",
"keys": {
"curve25519:OPQRSTU": "some+other+key",
},
"user_id": "@carol:example.org",
},
},
},
}
We then create a sync response that will indicate that Bob’s devices have been updated.
mock_aioresponse.get(
"https://matrix-client.example.org/_matrix/client/v3/sync?timeout=30000",
status=200,
body=json.dumps(
{
"device_lists": {
"changed": ["@bob:example.org"],
},
"next_batch": "token1",
}
),
headers={
"content-type": "application/json",
},
)
mock_aioresponse.get(
"https://matrix-client.example.org/_matrix/client/v3/sync?since=token1&timeout=30000",
status=200,
body='{"next_batch":"token1"}',
headers={
"content-type": "application/json",
},
repeat=True,
)
def subscriber(msg) -> None:
c.stop_sync()
c.subscribe(client.DeviceChanges, subscriber)
c.subscribe(client.SyncFailed, subscriber)
c.start_sync()
try:
await c.sync_task
except asyncio.CancelledError:
pass
We now ensure that when we try to get the device keys, it refreshes Bob’s keys and returns the new keys. We use a callback in our HTTP request mock handler to ensure that the request body only requests Bob’s keys.
def callback(url, **kwargs):
assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}
return aioresponses.CallbackResult(
status=200,
body=json.dumps(
{
"device_keys": {
"@bob:example.org": {
"VWXYZAB": {
"algorithms": [],
"device_id": "VWXYZAB",
"keys": {
"curve25519:HIJKLMN": "some+new+key",
},
"user_id": "@bob:example.org",
},
},
},
}
),
headers={
"Content-Type": "application/json",
},
)
mock_aioresponse.post(
"https://matrix-client.example.org/_matrix/client/v3/keys/query",
callback=callback,
)
assert await tracker.get_device_keys(
["@bob:example.org", "@carol:example.org"]
) == {
"@bob:example.org": {
"device_keys": {
"VWXYZAB": {
"algorithms": [],
"device_id": "VWXYZAB",
"keys": {
"curve25519:HIJKLMN": "some+new+key",
},
"user_id": "@bob:example.org",
},
},
},
"@carol:example.org": {
"device_keys": {
"OPQRSTU": {
"algorithms": [],
"device_id": "OPQRSTU",
"keys": {
"curve25519:OPQRSTU": "some+other+key",
},
"user_id": "@carol:example.org",
},
},
},
}
Edge cases
There are a couple of edge cases that we need to deal with. First of all, if
we are querying the server for the device keys for a user, and concurrently,
another call to get_device_keys
needs to fetch the device keys for the same
user, we don’t need to query the server for that user again. Instead, we can
just use the result that we get from the first query. While we could just make
a second request and not worry about efficiency, this raises the question of
how this will affect our caching – if the two results differ, which result
should be cached? This would also complicate handling for the second edge
case. And, as it turns out, the record-keeping that we do to handle this edge
case will also help with handling the second edge case. So while on the
surface it may seem that it is simpler to just make a second request rather
than synchronizing between calls to our function, it actually turns out to be
simpler to do the synchronization.
The second edge case is that if we are querying the server for the device keys for a user, and concurrently, we get a sync response saying that the user’s devices have changed, we do not know if the result of our server query represents the user’s old devices or their new devices. In this case, we could retry the request. But what do we do if we get another sync response saying that the user’s devices have changed while the second request is in-flight? How many times will we keep retrying? If we don’t limit the number of retries, this could loop infinitely. Instead of retrying, for the sake of simplicity, in our implementation we will simply return the first result, but we will not cache the value so that the next time we request the user’s devices, we will re-query. Even though we may not be obtaining the most up-to-date value, we are still providing a correct response for the time that the request was made. Other implementations could use some sort of retry mechanism, but should take care to ensure that it will return after a reasonable amount of time, even if the server continuously indicates that devices have changed.
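For implementations that do prefer to retry, one possible shape is a loop with a fixed attempt limit. The sketch below is not what this library does. The `query_with_bounded_retries` function, its `changed_during_query` callback, and the default of three attempts are all assumptions chosen for illustration. It returns the last result together with a flag saying whether that result is known to be fresh, so the caller can decide whether to cache it.

```python
import asyncio
from typing import Awaitable, Callable


async def query_with_bounded_retries(
    query: Callable[[], Awaitable[dict]],
    changed_during_query: Callable[[], bool],
    max_attempts: int = 3,
) -> tuple[dict, bool]:
    """Run ``query`` until no change is seen during it, or attempts run out.

    Returns the last result and whether it is known to be fresh.
    """
    result: dict = {}
    for _ in range(max_attempts):
        result = await query()
        if not changed_during_query():
            return result, True
    # Give up: hand back the last result, flagged as possibly stale.
    return result, False


async def _demo() -> tuple[dict, bool, int]:
    calls = 0
    change_flags = [True, False]  # a change arrives during the first query only

    async def query() -> dict:
        nonlocal calls
        calls += 1
        return {"attempt": calls}

    result, fresh = await query_with_bounded_retries(
        query, lambda: change_flags.pop(0)
    )
    return result, fresh, calls


demo_result = asyncio.run(_demo())
```

Because the number of attempts is bounded, the function returns even if the server keeps reporting changes. The cost is that the caller may still receive a possibly stale result, the same trade-off as the no-retry approach taken here.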
To handle the first edge case, when we make a server request to query a user’s
device keys, we will record in our DeviceTracker object that we are querying
that user’s keys. If we get another call to get_device_keys for that user’s
keys, we will note that there is already a request in-flight, and we can wait
for the first request to complete, and get the value from there. We will make
use of Python’s asyncio.Future class for this, which
represents a variable that will have a value in the future. Thus when we query
the server for a user’s device keys, we will store a Future, and when we
receive the result, we will resolve the Future to the user’s devices. We will
store the Futures in the in_flight member variable, which will be a dict
mapping user IDs to Futures.
self.in_flight: dict[str, asyncio.Future] = {}
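The idea of sharing one in-progress request through a Future can be seen in a much smaller setting. The following sketch is a simplified, standalone illustration and not the DeviceTracker code: the `Coalescer` class and its `fetch` callback are invented for the example, it handles one key per call rather than a batch of users, and it does not deal with task cancellation.

```python
import asyncio


class Coalescer:
    """Share one in-progress fetch among concurrent callers (illustrative only)."""

    def __init__(self, fetch):
        self.fetch = fetch  # async callable taking a key and returning a value
        self.in_flight: dict[str, asyncio.Future] = {}

    async def get(self, key: str):
        if key in self.in_flight:
            # Another caller is already fetching this key; wait for its result.
            return await self.in_flight[key]
        future = asyncio.get_running_loop().create_future()
        self.in_flight[key] = future
        try:
            value = await self.fetch(key)
        except Exception as e:
            future.set_exception(e)  # pass the failure on to any waiters
            raise
        else:
            future.set_result(value)
            return value
        finally:
            del self.in_flight[key]


async def _demo():
    calls = []

    async def fetch(key):
        calls.append(key)
        await asyncio.sleep(0.01)  # simulate network latency
        return f"devices of {key}"

    coalescer = Coalescer(fetch)
    results = await asyncio.gather(
        coalescer.get("@bob:example.org"), coalescer.get("@bob:example.org")
    )
    return results, calls


results, calls = asyncio.run(_demo())
```

Both callers receive the same value, but `fetch` runs only once, because the second caller finds the Future left by the first and waits on it.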
We now write the code chunks in our get_device_keys
function that we
mentioned above would be used for dealing with the edge cases.
The first chunk occurs before we make our request to the server. It will look
at the users that we were going to request from the server, and see if we have
in-flight requests for those users. If so, it will record the Futures for
those users, and drop them from our request. If not, it will mark those users
as having requests in-flight, since we will be making the request.
loop = asyncio.get_running_loop()
device_futures = {}
for user in users_needing_download.keys():
    # If we already have a request in-flight for the user, we can use that
    # result instead of re-requesting. Otherwise, record that we will be
    # making a request for that user.
    if user in self.in_flight:
        device_futures[user] = self.in_flight[user]
    else:
        self.in_flight[user] = loop.create_future()
for user in device_futures.keys():
    # drop users from our request if we're using the result from a future
    del users_needing_download[user]
The second chunk occurs after we have made our request to the server. If the
request raises an exception, we will set that exception on our Futures, so
that the exception gets passed to anything that was waiting on our result.
e = typing.cast(BaseException, sys.exc_info()[1])
for user in users_needing_download.keys():
    self.in_flight[user].set_exception(e)
    del self.in_flight[user]
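The effect of setting an exception on a Future can be checked with a few lines of plain asyncio. This is a standalone sketch; the `demo` and `waiter` names and the `RuntimeError` are made up for the example.

```python
import asyncio


async def demo() -> str:
    future = asyncio.get_running_loop().create_future()

    async def waiter() -> str:
        try:
            await future
        except RuntimeError as e:
            return f"waiter saw: {e}"
        return "waiter saw a result"

    task = asyncio.create_task(waiter())
    await asyncio.sleep(0)  # let the waiter start and block on the future
    future.set_exception(RuntimeError("request failed"))
    return await task


outcome = asyncio.run(demo())
```

The coroutine that was awaiting the Future has the exception raised at its `await`, just as if it had made the failing request itself.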
The third chunk sets the result of our Futures after a successful request.
So whether the request fails or succeeds, the Future will resolve.
self.in_flight[user].set_result(user_device_info)
del self.in_flight[user]
And the last chunk gets the result for any Futures that we are waiting on.
Note that we get the result for the Futures after we make our own request to
the server. That is because we don’t want to wait for the Futures to resolve
before we make our request; the requests can be made concurrently.
for user, future in device_futures.items():
    ret[user] = await future
Tests
@pytest.mark.asyncio
async def test_concurrent_device_requests(mock_aioresponse):
async with client.Client(
storage={
"access_token": "anaccesstoken",
"user_id": "@alice:example.org",
"device_id": "ABCDEFG",
"device_tracker.tracked_users": {
"@bob:example.org": True,
},
},
callbacks={},
base_client_url="https://matrix-client.example.org/_matrix/client/",
) as c:
{{concurrent device requests test}}
To test that we can make concurrent requests, we will first make a request for a user’s device keys. We will process the request for the user’s keys using a custom handler, and make another request for that user’s device keys within that handler, to simulate a request being made concurrently. We will ensure that the device tracker does not make another request for the user’s device keys.
tracker = devices.DeviceTracker(c)
second_request_task = None
async def callback(url, **kwargs):
assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}
nonlocal second_request_task
second_request_task = asyncio.create_task(
tracker.get_device_keys(["@bob:example.org"])
)
# give the second request some time to execute to make sure that it blocks
# on the future
await asyncio.sleep(0.2)
return aioresponses.CallbackResult(
status=200,
body=json.dumps(
{
"device_keys": {
"@bob:example.org": {
"HIJKLMN": {
"algorithms": [],
"device_id": "HIJKLMN",
"keys": {
"curve25519:HIJKLMN": "some+key",
},
"user_id": "@bob:example.org",
},
},
},
}
),
headers={
"Content-Type": "application/json",
},
)
mock_aioresponse.post(
"https://matrix-client.example.org/_matrix/client/v3/keys/query",
callback=callback,
)
assert await tracker.get_device_keys(["@bob:example.org"]) == {
"@bob:example.org": {
"device_keys": {
"HIJKLMN": {
"algorithms": [],
"device_id": "HIJKLMN",
"keys": {
"curve25519:HIJKLMN": "some+key",
},
"user_id": "@bob:example.org",
},
},
},
}
assert await second_request_task == {
"@bob:example.org": {
"device_keys": {
"HIJKLMN": {
"algorithms": [],
"device_id": "HIJKLMN",
"keys": {
"curve25519:HIJKLMN": "some+key",
},
"user_id": "@bob:example.org",
},
},
},
}
Now we will handle the second edge case, where a sync comes in while we’re in the process of querying the server for a user’s devices. As explained above, we will return the result that we obtain from the server, but we will not cache the result, since it could be outdated.
When we receive the DeviceChanges
message from the sync loop, we will check
whether the users that are in changed
or left
have an in-flight request,
and if so, we will mark those users indicating that their results should not be
cached.
self.do_not_cache: typing.Set[str] = set()

if user in self.in_flight:
    self.do_not_cache.add(user)
Then, in our _should_cache_result
function, we check whether the user is
marked for not caching, and if so, we return that they should not be cached
after clearing the flag. Since we know that there is no other request for that
user, it is safe to clear the flag.
if user in self.do_not_cache:
    self.do_not_cache.remove(user)
    return False
Tests
@pytest.mark.asyncio
async def test_update_during_device_request(mock_aioresponse):
async with client.Client(
storage={
"access_token": "anaccesstoken",
"user_id": "@alice:example.org",
"device_id": "ABCDEFG",
"device_tracker.tracked_users": {
"@bob:example.org": True,
},
},
callbacks={},
base_client_url="https://matrix-client.example.org/_matrix/client/",
) as c:
{{update during device request test}}
To test that a device update arriving during a request prevents the result
from being cached, we will first make a request for a user’s device keys.
Again, we will process the request for the user’s keys using a custom handler,
but this time in our handler, we will publish a DeviceChanges message to the
tracker, simulating what would happen if a sync response came in. We then
ensure that when we make another request, it will re-query the server.
tracker = devices.DeviceTracker(c)
async def callback1(url, **kwargs):
assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}
await c.publisher.publish(client.DeviceChanges(["@bob:example.org"], []))
return aioresponses.CallbackResult(
status=200,
body=json.dumps(
{
"device_keys": {
"@bob:example.org": {
"HIJKLMN": {
"algorithms": [],
"device_id": "HIJKLMN",
"keys": {
"curve25519:HIJKLMN": "some+key",
},
"user_id": "@bob:example.org",
},
},
},
}
),
headers={
"Content-Type": "application/json",
},
)
mock_aioresponse.post(
"https://matrix-client.example.org/_matrix/client/v3/keys/query",
callback=callback1,
)
def callback2(url, **kwargs):
assert kwargs["json"] == {"device_keys": {"@bob:example.org": []}}
return aioresponses.CallbackResult(
status=200,
body=json.dumps(
{
"device_keys": {
"@bob:example.org": {
"VWXYZAB": {
"algorithms": [],
"device_id": "VWXYZAB",
"keys": {
"curve25519:HIJKLMN": "some+new+key",
},
"user_id": "@bob:example.org",
},
},
},
}
),
headers={
"Content-Type": "application/json",
},
)
mock_aioresponse.post(
"https://matrix-client.example.org/_matrix/client/v3/keys/query",
callback=callback2,
)
assert await tracker.get_device_keys(["@bob:example.org"]) == {
"@bob:example.org": {
"device_keys": {
"HIJKLMN": {
"algorithms": [],
"device_id": "HIJKLMN",
"keys": {
"curve25519:HIJKLMN": "some+key",
},
"user_id": "@bob:example.org",
},
},
},
}
assert await tracker.get_device_keys(["@bob:example.org"]) == {
"@bob:example.org": {
"device_keys": {
"VWXYZAB": {
"algorithms": [],
"device_id": "VWXYZAB",
"keys": {
"curve25519:HIJKLMN": "some+new+key",
},
"user_id": "@bob:example.org",
},
},
},
}