diff --git a/app/src/main/java/com/protonvpn/android/servers/ServerManager2.kt b/app/src/main/java/com/protonvpn/android/servers/ServerManager2.kt index 4b7e54d15..345e55188 100644 --- a/app/src/main/java/com/protonvpn/android/servers/ServerManager2.kt +++ b/app/src/main/java/com/protonvpn/android/servers/ServerManager2.kt @@ -57,8 +57,6 @@ class ServerManager2 @Inject constructor( emitAll(serverManager.serverListVersion) } - /** The first value is emitted before servers are loaded */ - val isDownloadedAtLeastOnceFlow = serverManager.isDownloadedAtLeastOnceFlow /** The first value is emitted before servers are loaded */ val hasAnyCountryFlow = serverManager.hasCountriesFlow /** The first value is emitted before servers are loaded */ diff --git a/app/src/main/java/com/protonvpn/android/servers/ServersDataManager.kt b/app/src/main/java/com/protonvpn/android/servers/ServersDataManager.kt index 90668c20a..ce1f96b15 100644 --- a/app/src/main/java/com/protonvpn/android/servers/ServersDataManager.kt +++ b/app/src/main/java/com/protonvpn/android/servers/ServersDataManager.kt @@ -20,13 +20,17 @@ package com.protonvpn.android.servers import com.protonvpn.android.concurrency.VpnDispatcherProvider -import com.protonvpn.android.servers.api.ConnectingDomain +import com.protonvpn.android.di.WallClock import com.protonvpn.android.models.vpn.GatewayGroup -import com.protonvpn.android.servers.api.LoadUpdate import com.protonvpn.android.models.vpn.VpnCountry +import com.protonvpn.android.servers.api.ConnectingDomain +import com.protonvpn.android.servers.api.LoadUpdate import com.protonvpn.android.servers.api.LogicalsStatusId import com.protonvpn.android.utils.replace import kotlinx.coroutines.async +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.filterNotNull import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withContext @@ -38,43 +42,46 @@ class ServersDataManager @Inject constructor( private val dispatcherProvider: VpnDispatcherProvider, private val serversStore: ServersStore, private val updateServersWithBinaryStatus: UpdateServersWithBinaryStatus, + @WallClock private val wallClock: () -> Long, ) { - private data class ServerLists( + data class ServerLists( val allServers: List, val allServersByScore: List, val vpnCountries: List, val secureCoreExitCountries: List, val gateways: List, val statusId: LogicalsStatusId? = null, - ) + ) { + companion object { + val Empty = ServerLists( + allServers = emptyList(), + allServersByScore = emptyList(), + vpnCountries = emptyList(), + secureCoreExitCountries = emptyList(), + gateways = emptyList(), + statusId = null, + ) + } + } + + private val serverListsFlow = MutableStateFlow(null) + val serverLists: Flow = serverListsFlow.filterNotNull() + var lastUpdateTimestamp: Long = 0 + private set private data class UpdateResult( val statusId: LogicalsStatusId?, val servers: List, + val serverListUpdateTimestamp: Long?, ) // Protect all modifications with the mutex. Use updateWithMutex for common cases. private val updateMutex = Mutex() - private var serverLists = ServerLists( - allServers = emptyList(), - allServersByScore = emptyList(), - vpnCountries = emptyList(), - secureCoreExitCountries = emptyList(), - gateways = emptyList(), - statusId = null, - ) - - val statusId: LogicalsStatusId? get() = serverLists.statusId - val allServers: List get() = serverLists.allServers - val allServersByScore: List get() = serverLists.allServersByScore - val vpnCountries: List get() = serverLists.vpnCountries - val secureCoreExitCountries: List get() = serverLists.secureCoreExitCountries - val gateways: List get() = serverLists.gateways // Load servers from storage. Returns true if servers were loaded successfully. suspend fun load(): Boolean { val loaded = serversStore.load() - updateWithMutex(saveToStorage = false) { with(serversStore) { UpdateResult(serversStatusId, allServers) } } + updateWithMutex(saveToStorage = false) { with(serversStore) { UpdateResult(serversStatusId, allServers, lastUpdateTimestamp) } } return loaded } @@ -87,17 +94,18 @@ class ServersDataManager @Inject constructor( if (server.serverId in retainIDs) missingServerIDs.remove(server.serverId) } - val retainedServers = allServers.filter { it.serverId in missingServerIDs } - UpdateResult(newStatusId, serverList + retainedServers) + val retainedServers = currentServers().allServers.filter { it.serverId in missingServerIDs } + UpdateResult(newStatusId, serverList + retainedServers, wallClock()) } } else { - UpdateResult(newStatusId, serverList) + UpdateResult(newStatusId, serverList, wallClock()) } } } suspend fun updateServerDomainStatus(connectingDomain: ConnectingDomain) { updateWithMutex { + val allServers = currentServers().allServers val updatedServers = buildList(allServers.size) { allServers.forEach { currentServer -> val server = if (currentServer.connectingDomains.any { it.id == connectingDomain.id }) { @@ -112,15 +120,15 @@ class ServersDataManager @Inject constructor( add(server) } } - UpdateResult(statusId, updatedServers) + UpdateResult(currentServers().statusId, updatedServers, serverListUpdateTimestamp = null) } } suspend fun updateLoads(loadsList: List) { updateWithMutex { val loadsMap: Map = loadsList.associateBy { it.id } - val updatedServers = buildList(allServers.size) { - allServers.forEach { currentServer -> + val updatedServers = buildList(currentServers().allServers.size) { + currentServers().allServers.forEach { currentServer -> val newValues = loadsMap[currentServer.serverId] val server = if (newValues != null) { // Status update doesn't include physical servers, it's not safe to go from @@ -137,15 +145,23 @@ class ServersDataManager @Inject constructor( add(server) } } - UpdateResult(statusId, updatedServers) + UpdateResult(currentServers().statusId, updatedServers, serverListUpdateTimestamp = null) } } suspend fun updateBinaryLoads(statusId: LogicalsStatusId, statusData: ByteArray) { updateWithMutex { - if (statusId != this.statusId) return@updateWithMutex null + if (statusId != currentServers().statusId) return@updateWithMutex null val updatedServers = updateServersWithBinaryStatus(serversStore.allServers, statusData) - updatedServers?.let { UpdateResult(statusId, updatedServers) } + updatedServers?.let { + UpdateResult(statusId, updatedServers, serverListUpdateTimestamp = null) + } + } + } + + suspend fun updateLastUpdateTimestamp(timestamp: Long = wallClock()) { + updateWithMutex { + UpdateResult(currentServers().statusId, currentServers().allServers, timestamp) } } @@ -153,11 +169,12 @@ class ServersDataManager @Inject constructor( updateWithMutex { withContext(dispatcherProvider.Comp) { UpdateResult( - statusId, - allServers.toMutableList().apply { + currentServers().statusId, + currentServers().allServers.toMutableList().apply { removeIf { it.serverId == server.serverId } add(server) - } + }, + wallClock() ) } } @@ -168,25 +185,32 @@ class ServersDataManager @Inject constructor( updateBlock: suspend () -> UpdateResult?, ) { updateMutex.withLock { - val updateResult: Pair, ServerLists>? = withContext(dispatcherProvider.Comp) { + val updateResult: Triple, ServerLists, Long?>? = withContext(dispatcherProvider.Comp) { val update = updateBlock() ?: return@withContext null val newServers = update.servers.filter { it.isVisible } val groupedServers = async { updateServerLists(newServers, update.statusId) } val sortedServers = async { newServers.sortedBy { it.score } } val newServerLists = groupedServers.await() .copy(allServersByScore = sortedServers.await()) - update.servers to newServerLists + Triple(update.servers,newServerLists, update.serverListUpdateTimestamp) } if (updateResult != null) { - val (allServers, newServerLists) = updateResult - if (saveToStorage) { - serversStore.save(allServers, newServerLists.statusId) + val (allServers, newServerLists, updateTimestamp) = updateResult + if (updateTimestamp != null) { + lastUpdateTimestamp = updateTimestamp } - serverLists = newServerLists + if (saveToStorage) { + serversStore.save(allServers, newServerLists.statusId, lastUpdateTimestamp) + } + + serverListsFlow.value = newServerLists } } } + // Use only when protected by the updateMutex. + private fun currentServers() = serverListsFlow.value ?: ServerLists.Empty + companion object { private fun updateServerLists(newServerList: List, statusId: LogicalsStatusId?): ServerLists { fun MutableMap>.addServer( diff --git a/app/src/main/java/com/protonvpn/android/servers/ServersStore.kt b/app/src/main/java/com/protonvpn/android/servers/ServersStore.kt index 5596b39a1..1af8a132a 100644 --- a/app/src/main/java/com/protonvpn/android/servers/ServersStore.kt +++ b/app/src/main/java/com/protonvpn/android/servers/ServersStore.kt @@ -38,11 +38,14 @@ class ServersStore( private set var allServers: List = emptyList() private set + var lastUpdateTimestamp: Long = 0 + private set suspend fun load(): Boolean { val data = store.read() if (data != null) { serversStatusId = data.statusFileId + lastUpdateTimestamp = data.updateTimestamp allServers = if (data.allServers.isEmpty() && data.vpnCountries.isNotEmpty()) { extractServers(data.vpnCountries, data.secureCoreEntryCountries, data.secureCoreExitCountries) } else { @@ -52,10 +55,11 @@ class ServersStore( return data != null } - fun save(newServers: List, newStatusId: LogicalsStatusId?) { + fun save(newServers: List, newStatusId: LogicalsStatusId?, updateTimestamp: Long) { serversStatusId = newStatusId allServers = newServers - val data = ServersSerializationData(allServers, serversStatusId) + lastUpdateTimestamp = updateTimestamp + val data = ServersSerializationData(allServers, serversStatusId, updateTimestamp) store.store(data) } @@ -94,6 +98,7 @@ class ServersStore( class ServersSerializationData( val allServers: List = emptyList(), val statusFileId: String? = null, + val updateTimestamp: Long = 0, // Deprecated, used only for migration. val vpnCountries: List = emptyList(), diff --git a/app/src/main/java/com/protonvpn/android/servers/UpdateServerListFromApi.kt b/app/src/main/java/com/protonvpn/android/servers/UpdateServerListFromApi.kt index 8cd639232..570ba1cc0 100644 --- a/app/src/main/java/com/protonvpn/android/servers/UpdateServerListFromApi.kt +++ b/app/src/main/java/com/protonvpn/android/servers/UpdateServerListFromApi.kt @@ -48,8 +48,8 @@ import javax.inject.Inject class UpdateServerListFromApi @Inject constructor( private val api: ProtonApiRetroFit, private val dispatcherProvider: DispatcherProvider, - @WallClock private val wallClock: () -> Long, private val serverManager: ServerManager, + private val serversDataManager: ServersDataManager, private val prefs: ServerListUpdaterPrefs, private val updateWithBinaryStatus: UpdateServersWithBinaryStatus, private val binaryServerStatusEnabled: IsBinaryServerStatusEnabled, @@ -137,7 +137,7 @@ class UpdateServerListFromApi @Inject constructor( is FetchResult.NewServers, is FetchResult.NotModified -> { - serverManager.updateTimestamp() + serversDataManager.updateLastUpdateTimestamp() PeriodicActionResult(Result.Success, isSuccess = true) } } diff --git a/app/src/main/java/com/protonvpn/android/ui/home/ServerListUpdater.kt b/app/src/main/java/com/protonvpn/android/ui/home/ServerListUpdater.kt index c42b0bb58..4ccd920a7 100644 --- a/app/src/main/java/com/protonvpn/android/ui/home/ServerListUpdater.kt +++ b/app/src/main/java/com/protonvpn/android/ui/home/ServerListUpdater.kt @@ -39,6 +39,7 @@ import com.protonvpn.android.logging.LogCategory import com.protonvpn.android.logging.ProtonLogger import com.protonvpn.android.models.vpn.UserLocation import com.protonvpn.android.servers.IsBinaryServerStatusEnabled +import com.protonvpn.android.servers.ServersDataManager import com.protonvpn.android.servers.UpdateServerListFromApi import com.protonvpn.android.servers.api.ServersCountResponse import com.protonvpn.android.utils.ServerManager @@ -86,6 +87,7 @@ class ServerListUpdater @Inject constructor( private val scope: CoroutineScope, private val api: ProtonApiRetroFit, private val serverManager: ServerManager, + private val serversDataManager: ServersDataManager, private val currentUser: CurrentUser, private val vpnStateMonitor: VpnStateMonitor, userPlanManager: UserPlanManager, @@ -118,7 +120,7 @@ class ServerListUpdater @Inject constructor( ) suspend fun needsUpdate() = serverManager.needsUpdate() || - wallClock() - serverManager.lastUpdateTimestamp >= 4 * remoteConfig.first().foregroundDelayMs + wallClock() - serversDataManager.lastUpdateTimestamp >= 4 * remoteConfig.first().foregroundDelayMs init { migrateIpAddress() diff --git a/app/src/main/java/com/protonvpn/android/utils/ServerManager.kt b/app/src/main/java/com/protonvpn/android/utils/ServerManager.kt index a3214f964..db8cfa6a7 100644 --- a/app/src/main/java/com/protonvpn/android/utils/ServerManager.kt +++ b/app/src/main/java/com/protonvpn/android/utils/ServerManager.kt @@ -23,7 +23,6 @@ import com.protonvpn.android.BuildConfig import com.protonvpn.android.appconfig.UserCountryPhysical import com.protonvpn.android.auth.data.VpnUser import com.protonvpn.android.auth.data.hasAccessToServer -import com.protonvpn.android.di.WallClock import com.protonvpn.android.excludedlocations.ExcludedLocations import com.protonvpn.android.logging.ProtonLogger import com.protonvpn.android.models.profiles.Profile @@ -40,6 +39,7 @@ import com.protonvpn.android.redesign.vpn.isExcluded import com.protonvpn.android.redesign.vpn.satisfiesFeatures import com.protonvpn.android.servers.Server import com.protonvpn.android.servers.ServersDataManager +import com.protonvpn.android.servers.ServersDataManager.ServerLists import com.protonvpn.android.servers.ServersResult import com.protonvpn.android.servers.api.ConnectingDomain import com.protonvpn.android.servers.api.LoadUpdate @@ -47,9 +47,13 @@ import com.protonvpn.android.servers.api.LogicalsStatusId import com.protonvpn.android.vpn.ProtocolSelection import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.distinctUntilChanged import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.launchIn import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.stateIn import kotlinx.coroutines.launch import kotlinx.serialization.Serializable import javax.inject.Inject @@ -58,6 +62,7 @@ import javax.inject.Singleton @Serializable data class ServerManagerState( val serverListAppVersionCode: Int = 0, + @Deprecated("The timestamp is now stored by ServersDataManager.") val lastUpdateTimestamp: Long = 0L, /** Can be checked even before servers are loaded from storage */ @@ -68,9 +73,8 @@ data class ServerManagerState( @Deprecated("Use ServerManager2 in new code") @Singleton class ServerManager @Inject constructor( - private val mainScope: CoroutineScope, - @WallClock private val wallClock: () -> Long, - val serversData: ServersDataManager, + mainScope: CoroutineScope, + val serversDataManager: ServersDataManager, val physicalUserCountry: UserCountryPhysical, ) { private var savedState: ServerManagerState @@ -82,23 +86,25 @@ class ServerManager @Inject constructor( // combine to react to updates. val serverListVersion = MutableStateFlow(0) - val lastUpdateTimestamp get() = savedState.lastUpdateTimestamp - + val isDownloadedAtLeastOnce get() = serversDataManager.lastUpdateTimestamp > 0 /** Can be checked even before servers are loaded from storage */ - val isDownloadedAtLeastOnce get() = savedState.lastUpdateTimestamp > 0 - val isDownloadedAtLeastOnceFlow = serverListVersion.map { isDownloadedAtLeastOnce }.distinctUntilChanged() val hasCountriesFlow = serverListVersion.map { savedState.hasCountries }.distinctUntilChanged() val hasGatewaysFlow = serverListVersion.map { savedState.hasGateways }.distinctUntilChanged() + // The cached values are to be used only by legacy code. + private val serversDataCachedFlow = + serversDataManager.serverLists.stateIn(mainScope, SharingStarted.Eagerly, ServerLists.Empty) + private val serversDataCached get() = serversDataCachedFlow.value // May be empty. + suspend fun needsUpdate(): Boolean { - ensureLoaded() - return savedState.lastUpdateTimestamp == 0L || serversData.allServers.isEmpty() || + val serversData = serversDataManager.serverLists.first() + return serversDataManager.lastUpdateTimestamp == 0L || serversData.allServers.isEmpty() || !haveWireGuardSupport() || savedState.serverListAppVersionCode < BuildConfig.VERSION_CODE } - val logicalsStatusId get() = serversData.statusId - val allServers get() = serversData.allServers - val allServersByScore get() = serversData.allServersByScore + val logicalsStatusId get() = serversDataCached.statusId + val allServers get() = serversDataCached.allServers + val allServersByScore get() = serversDataCached.allServersByScore val freeCountries get() = getVpnCountries() @@ -114,7 +120,7 @@ class ServerManager @Inject constructor( } mainScope.launch { - val loaded = serversData.load() + val loaded = serversDataManager.load() if (!loaded) { // We had servers saved but failed to load them, reset state. updateAndSave( @@ -124,12 +130,26 @@ class ServerManager @Inject constructor( hasCountries = false, ) ) + } else if ( + serversDataManager.lastUpdateTimestamp == 0L && + savedState.lastUpdateTimestamp > 0 + ) { + serversDataManager.updateLastUpdateTimestamp(savedState.lastUpdateTimestamp) + updateAndSave(savedState.copy(lastUpdateTimestamp = 0L)) } - // Notify of loaded state and update after everything has been updated. isLoaded.value = true - onServersUpdate() } + serversDataManager.serverLists.onEach { serversData -> + updateAndSave( + savedState.copy( + serverListAppVersionCode = BuildConfig.VERSION_CODE, + hasGateways = serversData.gateways.isNotEmpty(), + hasCountries = serversData.vpnCountries.isNotEmpty(), + ) + ) + serverListVersion.value++ + }.launchIn(mainScope) } suspend fun ensureLoaded() { @@ -137,18 +157,14 @@ class ServerManager @Inject constructor( } fun getExitCountries(secureCore: Boolean) = if (secureCore) - serversData.secureCoreExitCountries else serversData.vpnCountries - - private fun onServersUpdate() { - ++serverListVersion.value - } + serversDataCached.secureCoreExitCountries else serversDataCached.vpnCountries override fun toString(): String { - val lastUpdateTimestampLog = savedState.lastUpdateTimestamp + val lastUpdateTimestampLog = serversDataManager.lastUpdateTimestamp .takeIf { it != 0L } ?.let { ProtonLogger.formatTime(it) } - return "vpnCountries: ${serversData.vpnCountries.size} gateways: ${serversData.gateways.size}" + - " exit: ${serversData.secureCoreExitCountries.size} " + + return "vpnCountries: ${serversDataCached.vpnCountries.size} gateways: ${serversDataCached.gateways.size}" + + " exit: ${serversDataCached.secureCoreExitCountries.size} " + "ServerManager Updated: $lastUpdateTimestampLog" } @@ -190,55 +206,35 @@ class ServerManager @Inject constructor( retainIDs: Set = emptySet() ) { ensureLoaded() - serversData.replaceServers(serverList, statusId, retainIDs) - - updateAndSave( - savedState.copy( - lastUpdateTimestamp = wallClock(), - serverListAppVersionCode = BuildConfig.VERSION_CODE, - hasGateways = serversData.gateways.isNotEmpty(), - hasCountries = serversData.vpnCountries.isNotEmpty(), - ) - ) - onServersUpdate() - } - - fun updateTimestamp() { - updateAndSave( - savedState.copy(lastUpdateTimestamp = wallClock()) - ) + serversDataManager.replaceServers(serverList, statusId, retainIDs) } suspend fun updateServerDomainStatus(connectingDomain: ConnectingDomain) { ensureLoaded() - serversData.updateServerDomainStatus(connectingDomain) - onServersUpdate() + serversDataManager.updateServerDomainStatus(connectingDomain) } suspend fun updateLoads(loadsList: List) { ensureLoaded() - serversData.updateLoads(loadsList) - onServersUpdate() + serversDataManager.updateLoads(loadsList) } suspend fun updateBinaryLoads(statusId: LogicalsStatusId, loads: ByteArray) { ensureLoaded() - serversData.updateBinaryLoads(statusId, loads) - onServersUpdate() + serversDataManager.updateBinaryLoads(statusId, loads) } suspend fun updateOrAddServer(server: Server) { ensureLoaded() - serversData.updateOrAddServer(server) - onServersUpdate() + serversDataManager.updateOrAddServer(server) } fun getServerById(id: String) = allServers.firstOrNull { it.serverId == id } ?: guestHoleServers?.firstOrNull { it.serverId == id } - fun getVpnCountries(): List = serversData.vpnCountries.sortedByLocaleAware { it.countryName } + fun getVpnCountries(): List = serversDataCached.vpnCountries.sortedByLocaleAware { it.countryName } - fun getGateways(): List = serversData.gateways + fun getGateways(): List = serversDataCached.gateways @Deprecated("Use the suspending getVpnExitCountry from ServerManager2") fun getVpnExitCountry(countryCode: String, secureCoreCountry: Boolean): VpnCountry? = @@ -283,7 +279,7 @@ class ServerManager @Inject constructor( } fun getSecureCoreExitCountries(): List = - serversData.secureCoreExitCountries.sortedByLocaleAware { it.countryName } + serversDataCached.secureCoreExitCountries.sortedByLocaleAware { it.countryName } @Deprecated("Use getServerForConnectIntent") fun getServerForProfile( @@ -475,7 +471,7 @@ class ServerManager @Inject constructor( } private fun haveWireGuardSupport() = - serversData.allServers.any { server -> server.connectingDomains.any { it.publicKeyX25519 != null } } + serversDataCached.allServers.any { server -> server.connectingDomains.any { it.publicKeyX25519 != null } } private fun Server?.handleServersResult( onServersResult: (ServersResult) -> T, diff --git a/app/src/test/java/com/protonvpn/app/servers/ServerManagerStorageTests.kt b/app/src/test/java/com/protonvpn/app/servers/ServerManagerStorageTests.kt index e9ec53d6d..6a87b8026 100644 --- a/app/src/test/java/com/protonvpn/app/servers/ServerManagerStorageTests.kt +++ b/app/src/test/java/com/protonvpn/app/servers/ServerManagerStorageTests.kt @@ -35,6 +35,7 @@ import kotlinx.coroutines.test.StandardTestDispatcher import kotlinx.coroutines.test.TestDispatcher import kotlinx.coroutines.test.TestScope import kotlinx.coroutines.test.currentTime +import kotlinx.coroutines.test.runCurrent import kotlinx.coroutines.test.runTest import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse @@ -62,6 +63,7 @@ class ServerManagerStorageTests { testDispatcherProvider, createInMemoryServersStore(emptyList(), ""), FakeUpdateServersWithBinaryStatus(), + testScope::currentTime, ) } @@ -75,14 +77,14 @@ class ServerManagerStorageTests { } val serverManager = ServerManager( mainScope = testScope.backgroundScope, - wallClock = testScope::currentTime, - serversData = serversDataManager, - createNoopUserCountry(), + serversDataManager = serversDataManager, + physicalUserCountry = createNoopUserCountry(), ) - - assertTrue(serverManager.isDownloadedAtLeastOnce) - assertEquals(1769690569069L, serverManager.lastUpdateTimestamp) assertTrue(serverManager.hasCountriesFlow.first()) assertFalse(serverManager.hasGatewaysFlow.first()) + + serverManager.ensureLoaded() + assertTrue(serverManager.isDownloadedAtLeastOnce) + assertEquals(1769690569069L, serversDataManager.lastUpdateTimestamp) } } \ No newline at end of file diff --git a/app/src/test/java/com/protonvpn/app/ui/home/ServerListUpdaterTests.kt b/app/src/test/java/com/protonvpn/app/ui/home/ServerListUpdaterTests.kt index 2f6911223..d9f8bedbb 100644 --- a/app/src/test/java/com/protonvpn/app/ui/home/ServerListUpdaterTests.kt +++ b/app/src/test/java/com/protonvpn/app/ui/home/ServerListUpdaterTests.kt @@ -30,6 +30,7 @@ import com.protonvpn.android.models.vpn.data.LogicalsMetadata import com.protonvpn.android.servers.FakeIsBinaryServerStatusFeatureFlagEnabled import com.protonvpn.android.servers.IsBinaryServerStatusEnabled import com.protonvpn.android.servers.Server +import com.protonvpn.android.servers.ServersDataManager import com.protonvpn.android.servers.UpdateServerListFromApi import com.protonvpn.android.servers.api.LogicalServer import com.protonvpn.android.servers.api.LogicalServerV1 @@ -49,6 +50,7 @@ import com.protonvpn.android.vpn.usecases.FakeServerListTruncationEnabled import com.protonvpn.android.vpn.usecases.GetTruncationMustHaveIDs import com.protonvpn.mocks.FakeUpdateServersWithBinaryStatus import com.protonvpn.mocks.createInMemoryServerManager +import com.protonvpn.mocks.createInMemoryServersDataManager import com.protonvpn.test.shared.MockSharedPreference import com.protonvpn.test.shared.MockSharedPreferencesProvider import com.protonvpn.test.shared.MockedServers @@ -128,6 +130,7 @@ class ServerListUpdaterTests { private lateinit var vpnStateMonitor: VpnStateMonitor private lateinit var mustHaveIDs: Set private lateinit var binaryStatusFfEnabled: MutableStateFlow + private lateinit var serversDataManager: ServersDataManager private lateinit var serverManager: ServerManager private lateinit var truncationEnabled: MutableStateFlow private lateinit var runWhileGettingServerList: () -> Unit @@ -186,7 +189,8 @@ class ServerListUpdaterTests { mustHaveIDs = emptySet() binaryStatusFfEnabled = MutableStateFlow(false) - serverManager = createInMemoryServerManager(testScope, TestDispatcherProvider(testDispatcher), initialServers = emptyList()) + serversDataManager = createInMemoryServersDataManager(testScope, TestDispatcherProvider(testDispatcher)) + serverManager = createInMemoryServerManager(testScope, serversDataManager) truncationEnabled = MutableStateFlow(true) val getNetZone = GetNetZone(serverListUpdaterPrefs) val serverListTruncationFF = FakeServerListTruncationEnabled(truncationEnabled) @@ -198,8 +202,8 @@ class ServerListUpdaterTests { val updateServerListFromApi = UpdateServerListFromApi( mockApi, TestDispatcherProvider(testDispatcher), - testScope::currentTime, serverManager, + serversDataManager, serverListUpdaterPrefs, fakeUpdateWithBinaryStatus, binaryServerStatusEnabled, @@ -210,6 +214,7 @@ class ServerListUpdaterTests { scope = testScope.backgroundScope, api = mockApi, serverManager = serverManager, + serversDataManager = serversDataManager, currentUser = mockCurrentUser, vpnStateMonitor = vpnStateMonitor, userPlanManager = mockPlanManager, @@ -342,7 +347,7 @@ class ServerListUpdaterTests { val result1 = serverListUpdater.updateServers() assertEquals(successResult, result1.result) assertEquals(listOf("id1"), serverManager.allServers.map { it.serverId }) - assertEquals(firstUpdateTimestamp, serverManager.lastUpdateTimestamp) + assertEquals(firstUpdateTimestamp, serversDataManager.lastUpdateTimestamp) // Version will not change for the next call advanceTimeBy(5.minutes) @@ -352,7 +357,7 @@ class ServerListUpdaterTests { // 304 does not result in a call to setServers but will refresh timestamp. assertEquals(listOf("id1"), serverManager.allServers.map { it.serverId }) assertEquals(firstUpdateTimestamp, serverListUpdaterPrefs.serverListLastModified) - assertEquals(currentTime, serverManager.lastUpdateTimestamp) + assertEquals(currentTime, serversDataManager.lastUpdateTimestamp) // Make new version available fakeServerListV2Backend.serverLastModified = { currentTime } @@ -360,7 +365,7 @@ class ServerListUpdaterTests { val result3 = serverListUpdater.updateServers() assertEquals(successResult, result3.result) assertEquals(currentTime, serverListUpdaterPrefs.serverListLastModified) - assertEquals(currentTime, serverManager.lastUpdateTimestamp) + assertEquals(currentTime, serversDataManager.lastUpdateTimestamp) assertEquals(listOf("id2"), serverManager.allServers.map { it.serverId }) } diff --git a/app/src/test/java/com/protonvpn/app/vpn/ServersDataManagerTests.kt b/app/src/test/java/com/protonvpn/app/vpn/ServersDataManagerTests.kt index 865ba6cb1..24bc7f0ce 100644 --- a/app/src/test/java/com/protonvpn/app/vpn/ServersDataManagerTests.kt +++ b/app/src/test/java/com/protonvpn/app/vpn/ServersDataManagerTests.kt @@ -31,9 +31,11 @@ import com.protonvpn.test.shared.createServer import junit.framework.TestCase.assertEquals import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.first import kotlinx.coroutines.test.TestDispatcher import kotlinx.coroutines.test.TestScope import kotlinx.coroutines.test.UnconfinedTestDispatcher +import kotlinx.coroutines.test.currentTime import kotlinx.coroutines.test.resetMain import kotlinx.coroutines.test.runTest import kotlinx.coroutines.test.setMain @@ -62,6 +64,7 @@ class ServersDataManagerTests { TestDispatcherProvider(testDispatcher), store, fakeServerStateUpdater, + testScope::currentTime, ) } @@ -89,7 +92,7 @@ class ServersDataManagerTests { null, retainIDs = setOf("1", "2") ) - assertEquals(setOf("1", "2", "4"), manager.allServers.toIds()) + assertEquals(setOf("1", "2", "4"), manager.allServers().toIds()) } @Test @@ -105,18 +108,20 @@ class ServersDataManagerTests { val statusId = "status ID" manager.replaceServers(servers, statusId, emptySet()) - assertEquals(setOf("1", "3", "5"), manager.allServers.toIds()) - assertEquals(setOf("1"), manager.vpnCountries.find { it.flag == "PL" }?.serverList?.toIds()) - assertEquals(setOf("3"), manager.gateways.find { it.name() == "company" }?.serverList?.toIds()) - assertEquals(setOf("5"), manager.secureCoreExitCountries.find { it.flag == "PL" }?.serverList?.toIds()) + val afterReplace = manager.serverLists.first() + assertEquals(setOf("1", "3", "5"), afterReplace.allServers.toIds()) + assertEquals(setOf("1"), afterReplace.vpnCountries.find { it.flag == "PL" }?.serverList?.toIds()) + assertEquals(setOf("3"), afterReplace.gateways.find { it.name() == "company" }?.serverList?.toIds()) + assertEquals(setOf("5"), afterReplace.secureCoreExitCountries.find { it.flag == "PL" }?.serverList?.toIds()) fakeServerStateUpdater.mapsAllServers { it.copy(isVisible = true) } manager.updateBinaryLoads(statusId, ByteArray(0)) - assertEquals(setOf("1", "2", "3", "4", "5", "6"), manager.allServers.toIds()) - assertEquals(setOf("1", "2"), manager.vpnCountries.find { it.flag == "PL" }?.serverList?.toIds()) - assertEquals(setOf("3", "4"), manager.gateways.find { it.name() == "company" }?.serverList?.toIds()) - assertEquals(setOf("5", "6"), manager.secureCoreExitCountries.find { it.flag == "PL" }?.serverList?.toIds()) + val afterUpdate = manager.serverLists.first() + assertEquals(setOf("1", "2", "3", "4", "5", "6"), afterUpdate.allServers.toIds()) + assertEquals(setOf("1", "2"), afterUpdate.vpnCountries.find { it.flag == "PL" }?.serverList?.toIds()) + assertEquals(setOf("3", "4"), afterUpdate.gateways.find { it.name() == "company" }?.serverList?.toIds()) + assertEquals(setOf("5", "6"), afterUpdate.secureCoreExitCountries.find { it.flag == "PL" }?.serverList?.toIds()) } @Test @@ -130,11 +135,13 @@ class ServersDataManagerTests { manager.replaceServers(servers, statusIdCurrent, emptySet()) manager.updateBinaryLoads(statusIdOld, ByteArray(0)) - assertEquals(listOf(50f, 50f), manager.allServers.map { it.load }) + assertEquals(listOf(50f, 50f), manager.allServers().map { it.load }) manager.updateBinaryLoads(statusIdCurrent, ByteArray(0)) - assertEquals(listOf(newLoad, newLoad), manager.allServers.map { it.load }) + assertEquals(listOf(newLoad, newLoad), manager.allServers().map { it.load }) } + private suspend fun ServersDataManager.allServers() = serverLists.first().allServers + private fun Iterable.toIds(): Set = this.map { it.serverId }.toSet() } diff --git a/app/src/test/java/com/protonvpn/app/vpn/ServersStoreTests.kt b/app/src/test/java/com/protonvpn/app/vpn/ServersStoreTests.kt index 31361cc46..3c31efcb8 100644 --- a/app/src/test/java/com/protonvpn/app/vpn/ServersStoreTests.kt +++ b/app/src/test/java/com/protonvpn/app/vpn/ServersStoreTests.kt @@ -50,6 +50,7 @@ class ServersStoreTests { private val tmpFile = File("servers${FileObjectStore.TMP_SUFFIX}") private val serversStatusId = "statusId" private val servers = MockedServers.serverList + private val timestamp = 1234L private suspend fun createAndLoadServersStore(testFile: File = File("servers")) = ServersStore.create( @@ -80,31 +81,34 @@ class ServersStoreTests { fun `basic store, read and clear`() = testScope.runTest { val store = createAndLoadServersStore(testFile) - store.save(servers, serversStatusId) + store.save(servers, serversStatusId, timestamp) assertTrue(testFile.exists()) val store2 = createAndLoadServersStore(testFile) assertEquals(servers, store2.allServers) assertEquals(serversStatusId, store2.serversStatusId) + assertEquals(timestamp, store2.lastUpdateTimestamp) store.clear() val store3 = createAndLoadServersStore(testFile) assertEquals(emptyList(), store3.allServers) assertEquals(null, store3.serversStatusId) + assertEquals(0, store3.lastUpdateTimestamp) assertFalse(testFile.exists()) } @Test fun `recover from interrupted rename`() = testScope.runTest { val store = createAndLoadServersStore(testFile) - store.save(servers, serversStatusId) + store.save(servers, serversStatusId, timestamp) // Tmp write was successful but the rename failed testFile.renameTo(tmpFile) val store2 = createAndLoadServersStore(testFile) assertEquals(servers, store2.allServers) - assertEquals(serversStatusId, store.serversStatusId) + assertEquals(serversStatusId, store2.serversStatusId) + assertEquals(timestamp, store2.lastUpdateTimestamp) assertTrue(testFile.exists()) assertFalse(tmpFile.exists()) } @@ -112,7 +116,7 @@ class ServersStoreTests { @Test fun `recover from unfinished write`() = testScope.runTest { val store = createAndLoadServersStore(testFile) - store.save(servers, serversStatusId) + store.save(servers, serversStatusId, timestamp) // Tmp write was interrupted but test file exists tmpFile.writeText("corrupted") @@ -120,6 +124,7 @@ class ServersStoreTests { val store2 = createAndLoadServersStore(testFile) assertEquals(servers, store2.allServers) assertEquals(serversStatusId, store2.serversStatusId) + assertEquals(timestamp, store2.lastUpdateTimestamp) assertTrue(testFile.exists()) assertFalse(tmpFile.exists()) } @@ -127,7 +132,7 @@ class ServersStoreTests { @Test fun `recover from unfinished first save`() = testScope.runTest { val store = createAndLoadServersStore(testFile) - store.save(servers, serversStatusId) + store.save(servers, serversStatusId, timestamp) testFile.delete() tmpFile.writeText("corrupted") @@ -135,5 +140,6 @@ class ServersStoreTests { val store2 = createAndLoadServersStore(testFile) assertEquals(emptyList(), store2.allServers) assertEquals(null, store2.serversStatusId) + assertEquals(0, store2.lastUpdateTimestamp) } } diff --git a/shared-test-code/src/main/java/com/protonvpn/mocks/InMemoryServerManager.kt b/shared-test-code/src/main/java/com/protonvpn/mocks/InMemoryServerManager.kt index c0aa3a277..0c40c073e 100644 --- a/shared-test-code/src/main/java/com/protonvpn/mocks/InMemoryServerManager.kt +++ b/shared-test-code/src/main/java/com/protonvpn/mocks/InMemoryServerManager.kt @@ -38,6 +38,22 @@ import kotlinx.coroutines.test.currentTime import kotlinx.coroutines.test.runCurrent @OptIn(ExperimentalCoroutinesApi::class) +fun createInMemoryServersDataManager( + testScope: TestScope, + testDispatcherProvider: TestDispatcherProvider, + initialServers: List = emptyList(), + initialStatusId: LogicalsStatusId? = null, + updateWithBinaryStatus: UpdateServersWithBinaryStatus = FakeUpdateServersWithBinaryStatus(), +): ServersDataManager { + val serverStore = createInMemoryServersStore(initialServers, initialStatusId) + return ServersDataManager( + testDispatcherProvider, + serverStore, + updateWithBinaryStatus, + testScope::currentTime + ) +} + fun createInMemoryServerManager( testScope: TestScope, testDispatcherProvider: TestDispatcherProvider, @@ -45,22 +61,32 @@ fun createInMemoryServerManager( initialStatusId: LogicalsStatusId? = null, updateWithBinaryStatus: UpdateServersWithBinaryStatus = FakeUpdateServersWithBinaryStatus(), physicalUserCountry: UserCountryPhysical = createNoopUserCountry(), +) = + createInMemoryServerManager( + testScope = testScope, + serversDataManager = + createInMemoryServersDataManager( + testScope = testScope, + testDispatcherProvider = testDispatcherProvider, + initialServers = initialServers, + initialStatusId = initialStatusId, + updateWithBinaryStatus = updateWithBinaryStatus, + ), + physicalUserCountry = physicalUserCountry, + ) + +@OptIn(ExperimentalCoroutinesApi::class) +fun createInMemoryServerManager( + testScope: TestScope, + serversDataManager: ServersDataManager, + physicalUserCountry: UserCountryPhysical = createNoopUserCountry(), ): ServerManager { - val serverStore = createInMemoryServersStore(initialServers, initialStatusId) - val serversDataManager = ServersDataManager( - testDispatcherProvider, - serverStore, - updateWithBinaryStatus, - ) - val serverManager = ServerManager( - testScope.backgroundScope, - testScope::currentTime, - serversDataManager, - physicalUserCountry, - ) - testScope.launch { - serverManager.setServers(initialServers, initialStatusId) - } + val serverManager = + ServerManager( + testScope.backgroundScope, + serversDataManager, + physicalUserCountry, + ) testScope.runCurrent() return serverManager }