Rotate MollySocket device on VAPID change and improve cleanup (#694)

This commit is contained in:
Oscar Mira
2026-02-20 11:21:27 +01:00
10 changed files with 164 additions and 70 deletions

View File

@@ -5,6 +5,7 @@ import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.JsonMappingException
import im.molly.unifiedpush.model.ConnectionRequest
import im.molly.unifiedpush.model.ConnectionResult
import im.molly.unifiedpush.model.LinkStatus
import im.molly.unifiedpush.model.MollySocketDevice
import im.molly.unifiedpush.model.Response
import okhttp3.HttpUrl
@@ -19,13 +20,11 @@ import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.keyvalue.PhoneNumberPrivacyValues.PhoneNumberDiscoverabilityMode
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.linkdevice.LinkDeviceRepository
import org.thoughtcrime.securesms.net.SignalNetwork
import org.thoughtcrime.securesms.push.AccountManagerFactory
import org.thoughtcrime.securesms.registration.data.RegistrationRepository
import org.thoughtcrime.securesms.registration.secondary.DeviceNameCipher
import org.thoughtcrime.securesms.util.JsonUtils
import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.NetworkResult
import org.whispersystems.signalservice.api.account.AccountAttributes
import org.whispersystems.signalservice.internal.push.DeviceLimitExceededException
import java.io.IOException
@@ -54,18 +53,8 @@ object MollySocketRepository {
@Throws(IOException::class, DeviceLimitExceededException::class)
private fun verifyNewDevice(password: String): Int {
val verificationCode = when (val result = SignalNetwork.linkDevice.getDeviceVerificationCode()) {
is NetworkResult.Success -> result.result
is NetworkResult.ApplicationError -> throw result.throwable
is NetworkResult.NetworkError -> {
Log.i(TAG, "Network failure", result.getCause())
throw result.exception
}
is NetworkResult.StatusCodeError -> {
Log.i(TAG, "Status code failure", result.getCause())
throw result.exception
}
}
val fetchResult = AppDependencies.linkDeviceApi.getDeviceVerificationCode()
val verificationCode = fetchResult.successOrThrow()
val registrationId = KeyHelper.generateRegistrationId(false)
val encryptedDeviceName = DeviceNameCipher.encryptDeviceName(
@@ -103,11 +92,23 @@ object MollySocketRepository {
}
}
// If loadDevices() fails, optimistically assume the device is linked
fun MollySocketDevice.isLinked(): Boolean {
@Throws(IOException::class)
fun removeDevice(device: MollySocketDevice) {
AppDependencies.linkDeviceApi.removeDevice(device.deviceId).successOrThrow()
}
fun getDeviceStatus(device: MollySocketDevice): LinkStatus {
return when (device.isLinked()) {
true -> LinkStatus.LINKED
false -> LinkStatus.NOT_LINKED
else -> LinkStatus.UNKNOWN
}
}
private fun MollySocketDevice.isLinked(): Boolean? {
return LinkDeviceRepository.loadDevices()?.any {
it.id == deviceId && it.name == DEVICE_NAME
} ?: true
}
}
fun discoverMollySocketServer(url: HttpUrl): Boolean {
@@ -119,10 +120,7 @@ object MollySocketRepository {
Log.d(TAG, "Unexpected code: $response")
return false
}
val body = response.body ?: run {
Log.d(TAG, "No response body")
return false
}
val body = response.body
JsonUtils.fromJson(body.byteStream(), Response::class.java)
}
Log.d(TAG, "URL is OK")
@@ -164,11 +162,7 @@ object MollySocketRepository {
Log.d(TAG, "Unexpected code: $response")
return null
}
val body = response.body ?: run {
Log.d(TAG, "No response body")
return null
}
val body = response.body
val resp = JsonUtils.fromJson(body.byteStream(), Response::class.java)
val status = resp.mollySocket.status

View File

@@ -2,6 +2,7 @@ package im.molly.unifiedpush
import android.content.Context
import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.unifiedpush.android.connector.UnifiedPush
import org.unifiedpush.android.connector.ui.SelectDistributorDialogsBuilder
import org.unifiedpush.android.connector.ui.UnifiedPushFunctions
@@ -16,6 +17,8 @@ object UnifiedPushDistributor {
@JvmStatic
fun unregisterApp() {
UnifiedPush.unregisterApp(AppDependencies.application)
// MessagingReceiver.onUnregistered won't be called after the unregistration request
SignalStore.unifiedpush.endpoint = null
}
fun selectFirstDistributor() {

View File

@@ -90,15 +90,13 @@ class UnifiedPushSettingsFragment : DSLSettingsFragment(R.string.NotificationDel
dividerPref()
if (state.airGapped) {
val parameters = getServerParameters(state) ?: ""
clickPref(
title = DSLSettingsText.from(getString(R.string.UnifiedPushSettingsFragment__server_parameters)),
summary = DSLSettingsText.from(getString(R.string.UnifiedPushSettingsFragment__tap_to_copy_to_clipboard)),
iconEnd = DSLSettingsIcon.from(org.signal.core.ui.R.drawable.symbol_copy_android_24),
isEnabled = parameters.isNotEmpty(),
isEnabled = state.serverParameters != null,
onClick = {
writeTextToClipboard(requireContext(), "Server parameters", parameters)
viewModel.copyParamsToClipboard(requireContext())
},
)
} else {
@@ -108,8 +106,9 @@ class UnifiedPushSettingsFragment : DSLSettingsFragment(R.string.NotificationDel
title = DSLSettingsText.from(getString(R.string.UnifiedPushSettingsFragment__account_id)),
summary = DSLSettingsText.from(aciOrUnknown),
iconEnd = DSLSettingsIcon.from(org.signal.core.ui.R.drawable.symbol_copy_android_24),
isEnabled = state.aci != null,
onClick = {
writeTextToClipboard(requireContext(), "Account ID", aciOrUnknown)
viewModel.copyAciToClipboard(requireContext())
},
)
@@ -125,13 +124,6 @@ class UnifiedPushSettingsFragment : DSLSettingsFragment(R.string.NotificationDel
}
}
private fun getServerParameters(state: UnifiedPushSettingsState): String? {
val aci = state.aci ?: return null
val device = state.device ?: return null
val endpoint = state.endpoint ?: return null
return "connection add $aci ${device.deviceId} ${device.password} $endpoint"
}
@StringRes
private fun getStatusSummary(state: UnifiedPushSettingsState): Int {
return when {

View File

@@ -18,4 +18,11 @@ data class UnifiedPushSettingsState(
val selectedNotAck: Boolean,
val endpoint: String?,
val mollySocketUrl: String?,
)
) {
val serverParameters: String?
get() {
return if (aci != null && device != null && endpoint != null) {
"connection add $aci ${device.deviceId} ${device.password} $endpoint"
} else null
}
}

View File

@@ -1,18 +1,18 @@
package im.molly.unifiedpush.components.settings.app.notifications
import android.app.Application
import android.content.Context
import android.content.pm.PackageManager
import android.os.Build
import android.widget.Toast
import androidx.lifecycle.LiveData
import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import im.molly.unifiedpush.model.MollySocket
import org.signal.core.util.ThreadUtil
import org.thoughtcrime.securesms.R
import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.jobs.UnifiedPushRefreshJob
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.keyvalue.UnifiedPushValues
import org.thoughtcrime.securesms.util.Util.writeTextToClipboard
import org.thoughtcrime.securesms.util.livedata.Store
import org.unifiedpush.android.connector.UnifiedPush
@@ -80,12 +80,7 @@ class UnifiedPushSettingsViewModel(private val application: Application) : ViewM
}
fun setMollySocket(mollySocket: MollySocket) {
SignalStore.unifiedpush.apply {
airGapped = mollySocket is MollySocket.AirGapped
lastReceivedTime = 0
mollySocketUrl = (mollySocket as? MollySocket.WebServer)?.url
mollySocketVapid = mollySocket.vapid
}
SignalStore.unifiedpush.updateMollySocket(mollySocket)
refresh()
updateRegistration()
}
@@ -95,9 +90,34 @@ class UnifiedPushSettingsViewModel(private val application: Application) : ViewM
updateRegistration(pingOnRegister = true)
}
fun copyParamsToClipboard(context: Context) {
val parameters = state.value?.serverParameters ?: ""
writeTextToClipboard(context, "Server parameters", parameters)
}
fun copyAciToClipboard(context: Context) {
val aci = state.value?.aci ?: ""
writeTextToClipboard(context, "Account ID", aci)
}
class Factory(private val application: Application) : ViewModelProvider.Factory {
override fun <T : ViewModel> create(modelClass: Class<T>): T {
return requireNotNull(modelClass.cast(UnifiedPushSettingsViewModel(application)))
}
}
}
fun UnifiedPushValues.updateMollySocket(mollySocket: MollySocket) {
airGapped = mollySocket is MollySocket.AirGapped
mollySocketUrl = (mollySocket as? MollySocket.WebServer)?.url
val changed = vapidPublicKey != mollySocket.vapid
vapidPublicKey = mollySocket.vapid
if (changed) {
vapidKeySynced = false
endpoint = null
}
lastReceivedTime = 0
}

View File

@@ -9,16 +9,25 @@ data class MollySocketDevice(
}
}
enum class LinkStatus {
LINKED,
NOT_LINKED,
UNKNOWN,
}
enum class RegistrationStatus(val value: Int) {
UNKNOWN(0),
PENDING(1),
REGISTERED(2),
BAD_RESPONSE(3),
SERVER_ERROR(4),
/** The UUID is forbidden by the config of MollySocket */
FORBIDDEN_UUID(5),
/** The endpoint is forbidden by the config of MollySocket */
FORBIDDEN_ENDPOINT(6),
/** The account+password doesn't work anymore, and returns forbidden by Signal server */
FORBIDDEN_PASSWORD(7);
@@ -29,7 +38,7 @@ enum class RegistrationStatus(val value: Int) {
}
}
fun ConnectionResult?.toRegistrationStatus():RegistrationStatus = when (this) {
fun ConnectionResult?.toRegistrationStatus(): RegistrationStatus = when (this) {
ConnectionResult.OK -> RegistrationStatus.REGISTERED
ConnectionResult.INTERNAL_ERROR -> RegistrationStatus.SERVER_ERROR
ConnectionResult.FORBIDDEN -> RegistrationStatus.FORBIDDEN_PASSWORD

View File

@@ -581,8 +581,10 @@ public class ApplicationContext extends Application implements AppForegroundObse
private void updateUnifiedPushStatus(boolean enabled) {
SignalStore.unifiedpush().setEnabled(enabled);
if (enabled) {
UnifiedPushDistributor.registerApp(SignalStore.unifiedpush().getMollySocketVapid());
} else {
UnifiedPushDistributor.registerApp(SignalStore.unifiedpush().getVapidPublicKey());
} else if (!SignalStore.unifiedpush().getAirGapped()) {
// Delete registration only if it isn't air gapped,
// When air gapped, we want to avoid unnecessary endpoint rotation
UnifiedPushDistributor.unregisterApp();
}
AppDependencies.getJobManager().add(new UnifiedPushRefreshJob());

View File

@@ -6,6 +6,7 @@ import android.os.Build
import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import im.molly.unifiedpush.UnifiedPushDistributor
import im.molly.unifiedpush.components.settings.app.notifications.updateMollySocket
import im.molly.unifiedpush.model.MollySocket
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
@@ -126,12 +127,7 @@ class NotificationsSettingsViewModel(private val sharedPreferences: SharedPrefer
}
fun initializeMollySocket(mollySocket: MollySocket) {
SignalStore.unifiedpush.apply {
airGapped = mollySocket is MollySocket.AirGapped
lastReceivedTime = 0
mollySocketUrl = (mollySocket as? MollySocket.WebServer)?.url
mollySocketVapid = mollySocket.vapid
}
SignalStore.unifiedpush.updateMollySocket(mollySocket)
}
fun setPlayServicesErrorCode(errorCode: Int?) {

View File

@@ -1,9 +1,10 @@
package org.thoughtcrime.securesms.jobs
import im.molly.unifiedpush.MollySocketRepository
import im.molly.unifiedpush.MollySocketRepository.isLinked
import im.molly.unifiedpush.UnifiedPushDistributor
import im.molly.unifiedpush.UnifiedPushNotificationBuilder
import im.molly.unifiedpush.model.LinkStatus
import im.molly.unifiedpush.model.MollySocketDevice
import im.molly.unifiedpush.model.RegistrationStatus
import im.molly.unifiedpush.model.toRegistrationStatus
import okhttp3.HttpUrl.Companion.toHttpUrlOrNull
@@ -23,6 +24,11 @@ import java.util.concurrent.TimeUnit
/**
* Handles UnifiedPush registration and ensures the MollySocket status is up-to-date.
* Unregisters if the account is not registered or UnifiedPush is disabled.
*
* @param testPing Send a ping notification with the registration.
* Used during first registration to a new MollySocket server.
* @param fromNewEndpoint Avoid registration cycle: if the Job is run from
* the receiver "onNewEndpoint", it will not register to the distributor.
*/
class UnifiedPushRefreshJob private constructor(
private val testPing: Boolean,
@@ -54,14 +60,21 @@ class UnifiedPushRefreshJob private constructor(
Log.d(TAG, "Current registration status: $currentStatus")
if (!hasAccount || !enabled) {
if (!hasAccount) {
Log.w(TAG, "User is not registered. Skipping.")
return
}
if (!enabled) {
Log.d(TAG, "UnifiedPush is disabled.")
cleanupMollySocketDevice()
return
}
try {
val newStatus = checkRegistrationStatus()
validateMollySocketDevice()
val newStatus = checkRegistrationStatus()
if (currentStatus == newStatus) {
Log.d(TAG, "Registration status unchanged.")
} else {
@@ -100,15 +113,13 @@ class UnifiedPushRefreshJob private constructor(
val airGapped = SignalStore.unifiedpush.airGapped
val mollySocketUrl = SignalStore.unifiedpush.mollySocketUrl?.toHttpUrlOrNull()
val lastReceivedTime = SignalStore.unifiedpush.lastReceivedTime
val vapid = SignalStore.unifiedpush.mollySocketVapid
val vapid = SignalStore.unifiedpush.vapidPublicKey
Log.d(TAG, "Last notification received at: $lastReceivedTime")
SignalStore.unifiedpush.device?.let { device ->
if (!device.isLinked()) {
Log.w(TAG, "$device no longer linked, will be recreated.")
SignalStore.unifiedpush.device = null
}
if (vapid == null) {
Log.e(TAG, "Missing VAPID public key.")
return RegistrationStatus.PENDING
}
if (!fromNewEndpoint) {
@@ -126,9 +137,10 @@ class UnifiedPushRefreshJob private constructor(
}
val device = SignalStore.unifiedpush.device ?: try {
MollySocketRepository.createDevice().also { device ->
SignalStore.unifiedpush.device = device
Log.d(TAG, "Created new MollySocket device: $device")
MollySocketRepository.createDevice().also { newDevice ->
SignalStore.unifiedpush.device = newDevice
SignalStore.unifiedpush.vapidKeySynced = true
Log.d(TAG, "Created new MollySocket device: $newDevice")
}
} catch (e: DeviceLimitExceededException) {
Log.e(TAG, "Device limit exceeded: ${e.max} total devices already.")
@@ -160,6 +172,62 @@ class UnifiedPushRefreshJob private constructor(
return result.toRegistrationStatus()
}
@Throws(IOException::class)
private fun validateMollySocketDevice() {
val device = SignalStore.unifiedpush.device ?: return
val vapidSynced = SignalStore.unifiedpush.vapidKeySynced
when (MollySocketRepository.getDeviceStatus(device)) {
LinkStatus.LINKED -> {
if (!vapidSynced) {
Log.w(TAG, "VAPID key mismatch, will remove previous linked $device.")
removeAndClearDevice(device)
}
}
LinkStatus.NOT_LINKED -> {
Log.w(TAG, "$device no longer linked, will be recreated.")
clearDevice()
}
LinkStatus.UNKNOWN -> {
if (!vapidSynced) {
throw IOException("VAPID key mismatch, but cannot determine $device link status.")
}
return // Optimistically assume the device is linked
}
}
}
@Throws(IOException::class)
private fun cleanupMollySocketDevice() {
if (SignalStore.unifiedpush.airGapped) {
Log.d(TAG, "MollySocket is air-gapped, cleanup skipped.")
return
}
val device = SignalStore.unifiedpush.device ?: return
Log.d(TAG, "Cleaning up non-air-gapped $device...")
when (MollySocketRepository.getDeviceStatus(device)) {
LinkStatus.LINKED -> removeAndClearDevice(device)
LinkStatus.NOT_LINKED -> clearDevice()
LinkStatus.UNKNOWN -> {
throw IOException("Cannot determine $device link status during cleanup.")
}
}
}
@Throws(IOException::class)
private fun removeAndClearDevice(device: MollySocketDevice) {
MollySocketRepository.removeDevice(device)
clearDevice()
}
private fun clearDevice() {
SignalStore.unifiedpush.device = null
}
override fun onFailure() = Unit
public override fun onShouldRetry(throwable: Exception): Boolean {

View File

@@ -15,6 +15,7 @@ class UnifiedPushValues(store: KeyValueStore) : SignalStoreValues(store) {
private const val MOLLYSOCKET_AIR_GAPPED = "mollysocket.airGapped"
private const val MOLLYSOCKET_URL = "mollysocket.url"
private const val MOLLYSOCKET_VAPID = "mollysocket.vapid"
private const val MOLLYSOCKET_VAPID_SYNCED = "mollysocket.vapid.synced"
private const val UNIFIEDPUSH_ENABLED = "up.enabled"
private const val UNIFIEDPUSH_ENDPOINT = "up.endpoint"
private const val UNIFIEDPUSH_LAST_RECEIVED_TIME = "up.lastRecvTime"
@@ -56,7 +57,9 @@ class UnifiedPushValues(store: KeyValueStore) : SignalStoreValues(store) {
var mollySocketUrl: String? by stringValue(MOLLYSOCKET_URL, null)
var mollySocketVapid: String? by stringValue(MOLLYSOCKET_VAPID, null)
var vapidPublicKey: String? by stringValue(MOLLYSOCKET_VAPID, null)
var vapidKeySynced: Boolean by booleanValue(MOLLYSOCKET_VAPID_SYNCED, true)
var lastReceivedTime: Long by longValue(UNIFIEDPUSH_LAST_RECEIVED_TIME, 0)