Skip to content

Commit 45d4218

Browse files
committed
Registration step to set a display name
1 parent 336cfd7 commit 45d4218

File tree

11 files changed

+348
-4
lines changed

11 files changed

+348
-4
lines changed

crates/handlers/src/lib.rs

+5
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,11 @@ where
383383
get(self::views::register::steps::verify_email::get)
384384
.post(self::views::register::steps::verify_email::post),
385385
)
386+
.route(
387+
mas_router::RegisterDisplayName::route(),
388+
get(self::views::register::steps::display_name::get)
389+
.post(self::views::register::steps::display_name::post),
390+
)
386391
.route(
387392
mas_router::RegisterFinish::route(),
388393
get(self::views::register::steps::finish::get),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
// Copyright 2025 New Vector Ltd.
2+
//
3+
// SPDX-License-Identifier: AGPL-3.0-only
4+
// Please see LICENSE in the repository root for full details.
5+
6+
use anyhow::Context as _;
7+
use axum::{
8+
extract::{Path, State},
9+
response::{Html, IntoResponse, Response},
10+
Form,
11+
};
12+
use mas_axum_utils::{
13+
cookies::CookieJar,
14+
csrf::{CsrfExt as _, ProtectedForm},
15+
FancyError,
16+
};
17+
use mas_router::{PostAuthAction, UrlBuilder};
18+
use mas_storage::{BoxClock, BoxRepository, BoxRng};
19+
use mas_templates::{
20+
FieldError, RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField,
21+
TemplateContext as _, Templates, ToFormState,
22+
};
23+
use serde::{Deserialize, Serialize};
24+
use ulid::Ulid;
25+
26+
use crate::{views::shared::OptionalPostAuthAction, PreferredLanguage};
27+
28+
#[derive(Deserialize, Default)]
29+
#[serde(rename_all = "snake_case")]
30+
enum FormAction {
31+
#[default]
32+
Set,
33+
Skip,
34+
}
35+
36+
#[derive(Deserialize, Serialize)]
37+
pub(crate) struct DisplayNameForm {
38+
#[serde(skip_serializing, default)]
39+
action: FormAction,
40+
#[serde(default)]
41+
display_name: String,
42+
}
43+
44+
impl ToFormState for DisplayNameForm {
45+
type Field = mas_templates::RegisterStepsDisplayNameFormField;
46+
}
47+
48+
#[tracing::instrument(
49+
name = "handlers.views.register.steps.display_name.get",
50+
fields(user_registration.id = %id),
51+
skip_all,
52+
err,
53+
)]
54+
pub(crate) async fn get(
55+
mut rng: BoxRng,
56+
clock: BoxClock,
57+
PreferredLanguage(locale): PreferredLanguage,
58+
State(templates): State<Templates>,
59+
State(url_builder): State<UrlBuilder>,
60+
mut repo: BoxRepository,
61+
Path(id): Path<Ulid>,
62+
cookie_jar: CookieJar,
63+
) -> Result<Response, FancyError> {
64+
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
65+
66+
let registration = repo
67+
.user_registration()
68+
.lookup(id)
69+
.await?
70+
.context("Could not find user registration")?;
71+
72+
// If the registration is completed, we can go to the registration destination
73+
// XXX: this might not be the right thing to do? Maybe an error page would be
74+
// better?
75+
if registration.completed_at.is_some() {
76+
let post_auth_action: Option<PostAuthAction> = registration
77+
.post_auth_action
78+
.map(serde_json::from_value)
79+
.transpose()?;
80+
81+
return Ok((
82+
cookie_jar,
83+
OptionalPostAuthAction::from(post_auth_action)
84+
.go_next(&url_builder)
85+
.into_response(),
86+
)
87+
.into_response());
88+
}
89+
90+
let ctx = RegisterStepsDisplayNameContext::new()
91+
.with_csrf(csrf_token.form_value())
92+
.with_language(locale);
93+
94+
let content = templates.render_register_steps_display_name(&ctx)?;
95+
96+
Ok((cookie_jar, Html(content)).into_response())
97+
}
98+
99+
#[tracing::instrument(
100+
name = "handlers.views.register.steps.display_name.post",
101+
fields(user_registration.id = %id),
102+
skip_all,
103+
err,
104+
)]
105+
pub(crate) async fn post(
106+
mut rng: BoxRng,
107+
clock: BoxClock,
108+
PreferredLanguage(locale): PreferredLanguage,
109+
State(templates): State<Templates>,
110+
State(url_builder): State<UrlBuilder>,
111+
mut repo: BoxRepository,
112+
Path(id): Path<Ulid>,
113+
cookie_jar: CookieJar,
114+
Form(form): Form<ProtectedForm<DisplayNameForm>>,
115+
) -> Result<Response, FancyError> {
116+
let registration = repo
117+
.user_registration()
118+
.lookup(id)
119+
.await?
120+
.context("Could not find user registration")?;
121+
122+
// If the registration is completed, we can go to the registration destination
123+
// XXX: this might not be the right thing to do? Maybe an error page would be
124+
// better?
125+
if registration.completed_at.is_some() {
126+
let post_auth_action: Option<PostAuthAction> = registration
127+
.post_auth_action
128+
.map(serde_json::from_value)
129+
.transpose()?;
130+
131+
return Ok((
132+
cookie_jar,
133+
OptionalPostAuthAction::from(post_auth_action)
134+
.go_next(&url_builder)
135+
.into_response(),
136+
)
137+
.into_response());
138+
}
139+
140+
let form = cookie_jar.verify_form(&clock, form)?;
141+
142+
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
143+
144+
let display_name = match form.action {
145+
FormAction::Set => {
146+
let display_name = form.display_name.trim();
147+
148+
if display_name.is_empty() || display_name.len() > 255 {
149+
let ctx = RegisterStepsDisplayNameContext::new()
150+
.with_form_state(form.to_form_state().with_error_on_field(
151+
RegisterStepsDisplayNameFormField::DisplayName,
152+
FieldError::Invalid,
153+
))
154+
.with_csrf(csrf_token.form_value())
155+
.with_language(locale);
156+
157+
return Ok((
158+
cookie_jar,
159+
Html(templates.render_register_steps_display_name(&ctx)?),
160+
)
161+
.into_response());
162+
}
163+
164+
display_name.to_owned()
165+
}
166+
FormAction::Skip => {
167+
// If the user chose to skip, we do the same as Synapse and use the localpart as
168+
// default display name
169+
registration.username.clone()
170+
}
171+
};
172+
173+
let registration = repo
174+
.user_registration()
175+
.set_display_name(registration, display_name)
176+
.await?;
177+
178+
repo.save().await?;
179+
180+
let destination = mas_router::RegisterFinish::new(registration.id);
181+
return Ok((cookie_jar, url_builder.redirect(&destination)).into_response());
182+
}

crates/handlers/src/views/register/steps/finish.rs

+8
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,14 @@ pub(crate) async fn get(
102102
)));
103103
}
104104

105+
// Check that the display name is set
106+
if registration.display_name.is_none() {
107+
return Ok((
108+
cookie_jar,
109+
url_builder.redirect(&mas_router::RegisterDisplayName::new(registration.id)),
110+
));
111+
}
112+
105113
// Everuthing is good, let's complete the registration
106114
let registration = repo
107115
.user_registration()

crates/handlers/src/views/register/steps/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
// SPDX-License-Identifier: AGPL-3.0-only
44
// Please see LICENSE in the repository root for full details.
55

6+
pub(crate) mod display_name;
67
pub(crate) mod finish;
78
pub(crate) mod verify_email;

crates/router/src/endpoints.rs

+24
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,30 @@ impl From<Option<PostAuthAction>> for PasswordRegister {
444444
}
445445
}
446446

447+
/// `GET|POST /register/steps/:id/display-name`
448+
#[derive(Debug, Clone)]
449+
pub struct RegisterDisplayName {
450+
id: Ulid,
451+
}
452+
453+
impl RegisterDisplayName {
454+
#[must_use]
455+
pub fn new(id: Ulid) -> Self {
456+
Self { id }
457+
}
458+
}
459+
460+
impl Route for RegisterDisplayName {
461+
type Query = ();
462+
fn route() -> &'static str {
463+
"/register/steps/:id/display-name"
464+
}
465+
466+
fn path(&self) -> std::borrow::Cow<'static, str> {
467+
format!("/register/steps/{}/display-name", self.id).into()
468+
}
469+
}
470+
447471
/// `GET|POST /register/steps/:id/verify-email`
448472
#[derive(Debug, Clone)]
449473
pub struct RegisterVerifyEmail {

crates/templates/src/context.rs

+51
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,57 @@ impl TemplateContext for RegisterStepsVerifyEmailContext {
10001000
}
10011001
}
10021002

1003+
/// Fields for the display name form
1004+
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
1005+
#[serde(rename_all = "snake_case")]
1006+
pub enum RegisterStepsDisplayNameFormField {
1007+
/// The display name
1008+
DisplayName,
1009+
}
1010+
1011+
impl FormField for RegisterStepsDisplayNameFormField {
1012+
fn keep(&self) -> bool {
1013+
match self {
1014+
Self::DisplayName => true,
1015+
}
1016+
}
1017+
}
1018+
1019+
/// Context used by the `display_name.html` template
1020+
#[derive(Serialize, Default)]
1021+
pub struct RegisterStepsDisplayNameContext {
1022+
form: FormState<RegisterStepsDisplayNameFormField>,
1023+
}
1024+
1025+
impl RegisterStepsDisplayNameContext {
1026+
/// Constructs a context for the display name page
1027+
#[must_use]
1028+
pub fn new() -> Self {
1029+
Self::default()
1030+
}
1031+
1032+
/// Set the form state
1033+
#[must_use]
1034+
pub fn with_form_state(
1035+
mut self,
1036+
form_state: FormState<RegisterStepsDisplayNameFormField>,
1037+
) -> Self {
1038+
self.form = form_state;
1039+
self
1040+
}
1041+
}
1042+
1043+
impl TemplateContext for RegisterStepsDisplayNameContext {
1044+
fn sample(_now: chrono::DateTime<chrono::Utc>, _rng: &mut impl Rng) -> Vec<Self>
1045+
where
1046+
Self: Sized,
1047+
{
1048+
vec![Self {
1049+
form: FormState::default(),
1050+
}]
1051+
}
1052+
}
1053+
10031054
/// Fields of the account recovery start form
10041055
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
10051056
#[serde(rename_all = "snake_case")]

crates/templates/src/lib.rs

+5
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ pub use self::{
4141
PostAuthContextInner, ReauthContext, ReauthFormField, RecoveryExpiredContext,
4242
RecoveryFinishContext, RecoveryFinishFormField, RecoveryProgressContext,
4343
RecoveryStartContext, RecoveryStartFormField, RegisterContext, RegisterFormField,
44+
RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField,
4445
RegisterStepsVerifyEmailContext, RegisterStepsVerifyEmailFormField, SiteBranding,
4546
SiteConfigExt, SiteFeatures, TemplateContext, UpstreamExistingLinkContext,
4647
UpstreamRegister, UpstreamRegisterFormField, UpstreamSuggestLink, WithCaptcha, WithCsrf,
@@ -335,6 +336,9 @@ register_templates! {
335336
/// Render the email verification page
336337
pub fn render_register_steps_verify_email(WithLanguage<WithCsrf<RegisterStepsVerifyEmailContext>>) { "pages/register/steps/verify_email.html" }
337338

339+
/// Render the display name page
340+
pub fn render_register_steps_display_name(WithLanguage<WithCsrf<RegisterStepsDisplayNameContext>>) { "pages/register/steps/display_name.html" }
341+
338342
/// Render the client consent page
339343
pub fn render_consent(WithLanguage<WithCsrf<WithSession<ConsentContext>>>) { "pages/consent.html" }
340344

@@ -428,6 +432,7 @@ impl Templates {
428432
check::render_register(self, now, rng)?;
429433
check::render_password_register(self, now, rng)?;
430434
check::render_register_steps_verify_email(self, now, rng)?;
435+
check::render_register_steps_display_name(self, now, rng)?;
431436
check::render_consent(self, now, rng)?;
432437
check::render_policy_violation(self, now, rng)?;
433438
check::render_sso_login(self, now, rng)?;

templates/components/button.html

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
class="",
3030
value="",
3131
disabled=False,
32+
kind="primary",
3233
size="lg",
3334
autocomplete=False,
3435
autocorrect=False,
@@ -39,7 +40,7 @@
3940
type="{{ type }}"
4041
{% if disabled %}disabled{% endif %}
4142
class="cpd-button {{ class }}"
42-
data-kind="primary"
43+
data-kind="{{ kind }}"
4344
data-size="{{ size }}"
4445
{% if autocapitalize %}autocapitilize="{{ autocapitilize }}"{% endif %}
4546
{% if autocomplete %}autocomplete="{{ autocomplete }}"{% endif %}

0 commit comments

Comments
 (0)