mas_storage/oauth2/session.rs
1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Device, Session, User};
12use oauth2_types::scope::Scope;
13use rand_core::RngCore;
14use ulid::Ulid;
15
16use crate::{Clock, Pagination, pagination::Page, repository_impl};
17
/// The lifecycle state of an OAuth 2.0 session, used to filter sessions with
/// [`OAuth2SessionFilter::active_only`] and
/// [`OAuth2SessionFilter::finished_only`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OAuth2SessionState {
    /// The session has not been finished
    Active,
    /// The session has been marked as finished
    Finished,
}

impl OAuth2SessionState {
    /// Returns `true` if this is [`OAuth2SessionState::Active`]
    #[must_use]
    pub fn is_active(self) -> bool {
        matches!(self, Self::Active)
    }

    /// Returns `true` if this is [`OAuth2SessionState::Finished`]
    #[must_use]
    pub fn is_finished(self) -> bool {
        matches!(self, Self::Finished)
    }
}
33
/// The kind of OAuth 2.0 client a session belongs to, used to filter sessions
/// with [`OAuth2SessionFilter::only_static_clients`] and
/// [`OAuth2SessionFilter::only_dynamic_clients`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ClientKind {
    /// A static client
    Static,
    /// A dynamic client
    Dynamic,
}

impl ClientKind {
    /// Returns `true` if this is [`ClientKind::Static`]
    #[must_use]
    pub fn is_static(self) -> bool {
        matches!(self, Self::Static)
    }

    /// Returns `true` if this is [`ClientKind::Dynamic`]
    #[must_use]
    pub fn is_dynamic(self) -> bool {
        matches!(self, Self::Dynamic)
    }
}
45
46/// Filter parameters for listing OAuth 2.0 sessions
47#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
48pub struct OAuth2SessionFilter<'a> {
49    user: Option<&'a User>,
50    any_user: Option<bool>,
51    browser_session: Option<&'a BrowserSession>,
52    device: Option<&'a Device>,
53    client: Option<&'a Client>,
54    client_kind: Option<ClientKind>,
55    state: Option<OAuth2SessionState>,
56    scope: Option<&'a Scope>,
57    last_active_before: Option<DateTime<Utc>>,
58    last_active_after: Option<DateTime<Utc>>,
59}
60
61impl<'a> OAuth2SessionFilter<'a> {
62    /// Create a new [`OAuth2SessionFilter`] with default values
63    #[must_use]
64    pub fn new() -> Self {
65        Self::default()
66    }
67
68    /// List sessions for a specific user
69    #[must_use]
70    pub fn for_user(mut self, user: &'a User) -> Self {
71        self.user = Some(user);
72        self
73    }
74
75    /// Get the user filter
76    ///
77    /// Returns [`None`] if no user filter was set
78    #[must_use]
79    pub fn user(&self) -> Option<&'a User> {
80        self.user
81    }
82
83    /// List sessions which belong to any user
84    #[must_use]
85    pub fn for_any_user(mut self) -> Self {
86        self.any_user = Some(true);
87        self
88    }
89
90    /// List sessions which belong to no user
91    #[must_use]
92    pub fn for_no_user(mut self) -> Self {
93        self.any_user = Some(false);
94        self
95    }
96
97    /// Get the 'any user' filter
98    ///
99    /// Returns [`None`] if no 'any user' filter was set
100    #[must_use]
101    pub fn any_user(&self) -> Option<bool> {
102        self.any_user
103    }
104
105    /// List sessions started by a specific browser session
106    #[must_use]
107    pub fn for_browser_session(mut self, browser_session: &'a BrowserSession) -> Self {
108        self.browser_session = Some(browser_session);
109        self
110    }
111
112    /// Get the browser session filter
113    ///
114    /// Returns [`None`] if no browser session filter was set
115    #[must_use]
116    pub fn browser_session(&self) -> Option<&'a BrowserSession> {
117        self.browser_session
118    }
119
120    /// List sessions for a specific client
121    #[must_use]
122    pub fn for_client(mut self, client: &'a Client) -> Self {
123        self.client = Some(client);
124        self
125    }
126
127    /// Get the client filter
128    ///
129    /// Returns [`None`] if no client filter was set
130    #[must_use]
131    pub fn client(&self) -> Option<&'a Client> {
132        self.client
133    }
134
135    /// List only static clients
136    #[must_use]
137    pub fn only_static_clients(mut self) -> Self {
138        self.client_kind = Some(ClientKind::Static);
139        self
140    }
141
142    /// List only dynamic clients
143    #[must_use]
144    pub fn only_dynamic_clients(mut self) -> Self {
145        self.client_kind = Some(ClientKind::Dynamic);
146        self
147    }
148
149    /// Get the client kind filter
150    ///
151    /// Returns [`None`] if no client kind filter was set
152    #[must_use]
153    pub fn client_kind(&self) -> Option<ClientKind> {
154        self.client_kind
155    }
156
157    /// Only return sessions with a last active time before the given time
158    #[must_use]
159    pub fn with_last_active_before(mut self, last_active_before: DateTime<Utc>) -> Self {
160        self.last_active_before = Some(last_active_before);
161        self
162    }
163
164    /// Only return sessions with a last active time after the given time
165    #[must_use]
166    pub fn with_last_active_after(mut self, last_active_after: DateTime<Utc>) -> Self {
167        self.last_active_after = Some(last_active_after);
168        self
169    }
170
171    /// Get the last active before filter
172    ///
173    /// Returns [`None`] if no client filter was set
174    #[must_use]
175    pub fn last_active_before(&self) -> Option<DateTime<Utc>> {
176        self.last_active_before
177    }
178
179    /// Get the last active after filter
180    ///
181    /// Returns [`None`] if no client filter was set
182    #[must_use]
183    pub fn last_active_after(&self) -> Option<DateTime<Utc>> {
184        self.last_active_after
185    }
186
187    /// Only return active sessions
188    #[must_use]
189    pub fn active_only(mut self) -> Self {
190        self.state = Some(OAuth2SessionState::Active);
191        self
192    }
193
194    /// Only return finished sessions
195    #[must_use]
196    pub fn finished_only(mut self) -> Self {
197        self.state = Some(OAuth2SessionState::Finished);
198        self
199    }
200
201    /// Get the state filter
202    ///
203    /// Returns [`None`] if no state filter was set
204    #[must_use]
205    pub fn state(&self) -> Option<OAuth2SessionState> {
206        self.state
207    }
208
209    /// Only return sessions with the given scope
210    #[must_use]
211    pub fn with_scope(mut self, scope: &'a Scope) -> Self {
212        self.scope = Some(scope);
213        self
214    }
215
216    /// Get the scope filter
217    ///
218    /// Returns [`None`] if no scope filter was set
219    #[must_use]
220    pub fn scope(&self) -> Option<&'a Scope> {
221        self.scope
222    }
223
224    /// Only return sessions that have the given device in their scope
225    #[must_use]
226    pub fn for_device(mut self, device: &'a Device) -> Self {
227        self.device = Some(device);
228        self
229    }
230
231    /// Get the device filter
232    ///
233    /// Returns [`None`] if no device filter was set
234    #[must_use]
235    pub fn device(&self) -> Option<&'a Device> {
236        self.device
237    }
238}
239
240/// An [`OAuth2SessionRepository`] helps interacting with [`Session`]
241/// saved in the storage backend
242#[async_trait]
243pub trait OAuth2SessionRepository: Send + Sync {
244    /// The error type returned by the repository
245    type Error;
246
247    /// Lookup an [`Session`] by its ID
248    ///
249    /// Returns `None` if no [`Session`] was found
250    ///
251    /// # Parameters
252    ///
253    /// * `id`: The ID of the [`Session`] to lookup
254    ///
255    /// # Errors
256    ///
257    /// Returns [`Self::Error`] if the underlying repository fails
258    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
259
260    /// Create a new [`Session`] with the given parameters
261    ///
262    /// Returns the newly created [`Session`]
263    ///
264    /// # Parameters
265    ///
266    /// * `rng`: The random number generator to use
267    /// * `clock`: The clock used to generate timestamps
268    /// * `client`: The [`Client`] which created the [`Session`]
269    /// * `user`: The [`User`] for which the session should be created, if any
270    /// * `user_session`: The [`BrowserSession`] of the user which completed the
271    ///   authorization, if any
272    /// * `scope`: The [`Scope`] of the [`Session`]
273    ///
274    /// # Errors
275    ///
276    /// Returns [`Self::Error`] if the underlying repository fails
277    async fn add(
278        &mut self,
279        rng: &mut (dyn RngCore + Send),
280        clock: &dyn Clock,
281        client: &Client,
282        user: Option<&User>,
283        user_session: Option<&BrowserSession>,
284        scope: Scope,
285    ) -> Result<Session, Self::Error>;
286
287    /// Create a new [`Session`] out of a [`Client`] and a [`BrowserSession`]
288    ///
289    /// Returns the newly created [`Session`]
290    ///
291    /// # Parameters
292    ///
293    /// * `rng`: The random number generator to use
294    /// * `clock`: The clock used to generate timestamps
295    /// * `client`: The [`Client`] which created the [`Session`]
296    /// * `user_session`: The [`BrowserSession`] of the user which completed the
297    ///   authorization
298    /// * `scope`: The [`Scope`] of the [`Session`]
299    ///
300    /// # Errors
301    ///
302    /// Returns [`Self::Error`] if the underlying repository fails
303    async fn add_from_browser_session(
304        &mut self,
305        rng: &mut (dyn RngCore + Send),
306        clock: &dyn Clock,
307        client: &Client,
308        user_session: &BrowserSession,
309        scope: Scope,
310    ) -> Result<Session, Self::Error> {
311        self.add(
312            rng,
313            clock,
314            client,
315            Some(&user_session.user),
316            Some(user_session),
317            scope,
318        )
319        .await
320    }
321
322    /// Create a new [`Session`] for a [`Client`] using the client credentials
323    /// flow
324    ///
325    /// Returns the newly created [`Session`]
326    ///
327    /// # Parameters
328    ///
329    /// * `rng`: The random number generator to use
330    /// * `clock`: The clock used to generate timestamps
331    /// * `client`: The [`Client`] which created the [`Session`]
332    /// * `scope`: The [`Scope`] of the [`Session`]
333    ///
334    /// # Errors
335    ///
336    /// Returns [`Self::Error`] if the underlying repository fails
337    async fn add_from_client_credentials(
338        &mut self,
339        rng: &mut (dyn RngCore + Send),
340        clock: &dyn Clock,
341        client: &Client,
342        scope: Scope,
343    ) -> Result<Session, Self::Error> {
344        self.add(rng, clock, client, None, None, scope).await
345    }
346
347    /// Mark a [`Session`] as finished
348    ///
349    /// Returns the updated [`Session`]
350    ///
351    /// # Parameters
352    ///
353    /// * `clock`: The clock used to generate timestamps
354    /// * `session`: The [`Session`] to mark as finished
355    ///
356    /// # Errors
357    ///
358    /// Returns [`Self::Error`] if the underlying repository fails
359    async fn finish(&mut self, clock: &dyn Clock, session: Session)
360    -> Result<Session, Self::Error>;
361
362    /// Mark all the [`Session`] matching the given filter as finished
363    ///
364    /// Returns the number of sessions affected
365    ///
366    /// # Parameters
367    ///
368    /// * `clock`: The clock used to generate timestamps
369    /// * `filter`: The filter parameters
370    ///
371    /// # Errors
372    ///
373    /// Returns [`Self::Error`] if the underlying repository fails
374    async fn finish_bulk(
375        &mut self,
376        clock: &dyn Clock,
377        filter: OAuth2SessionFilter<'_>,
378    ) -> Result<usize, Self::Error>;
379
380    /// List [`Session`]s matching the given filter and pagination parameters
381    ///
382    /// # Parameters
383    ///
384    /// * `filter`: The filter parameters
385    /// * `pagination`: The pagination parameters
386    ///
387    /// # Errors
388    ///
389    /// Returns [`Self::Error`] if the underlying repository fails
390    async fn list(
391        &mut self,
392        filter: OAuth2SessionFilter<'_>,
393        pagination: Pagination,
394    ) -> Result<Page<Session>, Self::Error>;
395
396    /// Count [`Session`]s matching the given filter
397    ///
398    /// # Parameters
399    ///
400    /// * `filter`: The filter parameters
401    ///
402    /// # Errors
403    ///
404    /// Returns [`Self::Error`] if the underlying repository fails
405    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
406
407    /// Record a batch of [`Session`] activity
408    ///
409    /// # Parameters
410    ///
411    /// * `activity`: A list of tuples containing the session ID, the last
412    ///   activity timestamp and the IP address of the client
413    ///
414    /// # Errors
415    ///
416    /// Returns [`Self::Error`] if the underlying repository fails
417    async fn record_batch_activity(
418        &mut self,
419        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
420    ) -> Result<(), Self::Error>;
421
422    /// Record the user agent of a [`Session`]
423    ///
424    /// # Parameters
425    ///
426    /// * `session`: The [`Session`] to record the user agent for
427    /// * `user_agent`: The user agent to record
428    async fn record_user_agent(
429        &mut self,
430        session: Session,
431        user_agent: String,
432    ) -> Result<Session, Self::Error>;
433
434    /// Set the human name of a [`Session`]
435    ///
436    /// # Parameters
437    ///
438    /// * `session`: The [`Session`] to set the human name for
439    /// * `human_name`: The human name to set
440    async fn set_human_name(
441        &mut self,
442        session: Session,
443        human_name: Option<String>,
444    ) -> Result<Session, Self::Error>;
445}
446
// Invoke the crate's `repository_impl!` macro for `OAuth2SessionRepository`.
// NOTE(review): the macro is defined elsewhere in the crate; it presumably
// generates forwarding implementations of the trait (e.g. for boxed or wrapper
// repositories). Confirm against its definition.
//
// Every signature listed here must mirror the trait declaration above exactly.
// That includes the default methods `add_from_browser_session` and
// `add_from_client_credentials`. Keep the two lists in sync when the trait
// changes.
repository_impl!(OAuth2SessionRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user: Option<&User>,
        user_session: Option<&BrowserSession>,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn add_from_browser_session(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user_session: &BrowserSession,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn add_from_client_credentials(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn finish(&mut self, clock: &dyn Clock, session: Session)
        -> Result<Session, Self::Error>;

    async fn finish_bulk(
        &mut self,
        clock: &dyn Clock,
        filter: OAuth2SessionFilter<'_>,
    ) -> Result<usize, Self::Error>;

    async fn list(
        &mut self,
        filter: OAuth2SessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<Session>, Self::Error>;

    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;

    async fn record_batch_activity(
        &mut self,
        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
    ) -> Result<(), Self::Error>;

    async fn record_user_agent(
        &mut self,
        session: Session,
        user_agent: String,
    ) -> Result<Session, Self::Error>;

    async fn set_human_name(
        &mut self,
        session: Session,
        human_name: Option<String>,
    ) -> Result<Session, Self::Error>;
);