Expose server lists as flow from ServersDataManager [VPNAND-2467].

This commit is contained in:
Marcin Simonides
2026-02-10 13:01:38 +01:00
committed by MargeBot
parent 6c14f63261
commit e5dd9b8e4a
11 changed files with 211 additions and 140 deletions

View File

@@ -57,8 +57,6 @@ class ServerManager2 @Inject constructor(
emitAll(serverManager.serverListVersion)
}
/** The first value is emitted before servers are loaded */
val isDownloadedAtLeastOnceFlow = serverManager.isDownloadedAtLeastOnceFlow
/** The first value is emitted before servers are loaded */
val hasAnyCountryFlow = serverManager.hasCountriesFlow
/** The first value is emitted before servers are loaded */

View File

@@ -20,13 +20,17 @@
package com.protonvpn.android.servers
import com.protonvpn.android.concurrency.VpnDispatcherProvider
import com.protonvpn.android.servers.api.ConnectingDomain
import com.protonvpn.android.di.WallClock
import com.protonvpn.android.models.vpn.GatewayGroup
import com.protonvpn.android.servers.api.LoadUpdate
import com.protonvpn.android.models.vpn.VpnCountry
import com.protonvpn.android.servers.api.ConnectingDomain
import com.protonvpn.android.servers.api.LoadUpdate
import com.protonvpn.android.servers.api.LogicalsStatusId
import com.protonvpn.android.utils.replace
import kotlinx.coroutines.async
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
@@ -38,43 +42,46 @@ class ServersDataManager @Inject constructor(
private val dispatcherProvider: VpnDispatcherProvider,
private val serversStore: ServersStore,
private val updateServersWithBinaryStatus: UpdateServersWithBinaryStatus,
@WallClock private val wallClock: () -> Long,
) {
private data class ServerLists(
data class ServerLists(
val allServers: List<Server>,
val allServersByScore: List<Server>,
val vpnCountries: List<VpnCountry>,
val secureCoreExitCountries: List<VpnCountry>,
val gateways: List<GatewayGroup>,
val statusId: LogicalsStatusId? = null,
)
) {
companion object {
val Empty = ServerLists(
allServers = emptyList(),
allServersByScore = emptyList(),
vpnCountries = emptyList(),
secureCoreExitCountries = emptyList(),
gateways = emptyList(),
statusId = null,
)
}
}
private val serverListsFlow = MutableStateFlow<ServerLists?>(null)
val serverLists: Flow<ServerLists> = serverListsFlow.filterNotNull()
var lastUpdateTimestamp: Long = 0
private set
private data class UpdateResult(
val statusId: LogicalsStatusId?,
val servers: List<Server>,
val serverListUpdateTimestamp: Long?,
)
// Protect all modifications with the mutex. Use updateWithMutex for common cases.
private val updateMutex = Mutex()
private var serverLists = ServerLists(
allServers = emptyList(),
allServersByScore = emptyList(),
vpnCountries = emptyList(),
secureCoreExitCountries = emptyList(),
gateways = emptyList(),
statusId = null,
)
val statusId: LogicalsStatusId? get() = serverLists.statusId
val allServers: List<Server> get() = serverLists.allServers
val allServersByScore: List<Server> get() = serverLists.allServersByScore
val vpnCountries: List<VpnCountry> get() = serverLists.vpnCountries
val secureCoreExitCountries: List<VpnCountry> get() = serverLists.secureCoreExitCountries
val gateways: List<GatewayGroup> get() = serverLists.gateways
// Load servers from storage. Returns true if servers were loaded successfully.
suspend fun load(): Boolean {
val loaded = serversStore.load()
updateWithMutex(saveToStorage = false) { with(serversStore) { UpdateResult(serversStatusId, allServers) } }
updateWithMutex(saveToStorage = false) { with(serversStore) { UpdateResult(serversStatusId, allServers, lastUpdateTimestamp) } }
return loaded
}
@@ -87,17 +94,18 @@ class ServersDataManager @Inject constructor(
if (server.serverId in retainIDs)
missingServerIDs.remove(server.serverId)
}
val retainedServers = allServers.filter { it.serverId in missingServerIDs }
UpdateResult(newStatusId, serverList + retainedServers)
val retainedServers = currentServers().allServers.filter { it.serverId in missingServerIDs }
UpdateResult(newStatusId, serverList + retainedServers, wallClock())
}
} else {
UpdateResult(newStatusId, serverList)
UpdateResult(newStatusId, serverList, wallClock())
}
}
}
suspend fun updateServerDomainStatus(connectingDomain: ConnectingDomain) {
updateWithMutex {
val allServers = currentServers().allServers
val updatedServers = buildList(allServers.size) {
allServers.forEach { currentServer ->
val server = if (currentServer.connectingDomains.any { it.id == connectingDomain.id }) {
@@ -112,15 +120,15 @@ class ServersDataManager @Inject constructor(
add(server)
}
}
UpdateResult(statusId, updatedServers)
UpdateResult(currentServers().statusId, updatedServers, serverListUpdateTimestamp = null)
}
}
suspend fun updateLoads(loadsList: List<LoadUpdate>) {
updateWithMutex {
val loadsMap: Map<String, LoadUpdate> = loadsList.associateBy { it.id }
val updatedServers = buildList(allServers.size) {
allServers.forEach { currentServer ->
val updatedServers = buildList(currentServers().allServers.size) {
currentServers().allServers.forEach { currentServer ->
val newValues = loadsMap[currentServer.serverId]
val server = if (newValues != null) {
// Status update doesn't include physical servers, it's not safe to go from
@@ -137,15 +145,23 @@ class ServersDataManager @Inject constructor(
add(server)
}
}
UpdateResult(statusId, updatedServers)
UpdateResult(currentServers().statusId, updatedServers, serverListUpdateTimestamp = null)
}
}
suspend fun updateBinaryLoads(statusId: LogicalsStatusId, statusData: ByteArray) {
updateWithMutex {
if (statusId != this.statusId) return@updateWithMutex null
if (statusId != currentServers().statusId) return@updateWithMutex null
val updatedServers = updateServersWithBinaryStatus(serversStore.allServers, statusData)
updatedServers?.let { UpdateResult(statusId, updatedServers) }
updatedServers?.let {
UpdateResult(statusId, updatedServers, serverListUpdateTimestamp = null)
}
}
}
suspend fun updateLastUpdateTimestamp(timestamp: Long = wallClock()) {
updateWithMutex {
UpdateResult(currentServers().statusId, currentServers().allServers, timestamp)
}
}
@@ -153,11 +169,12 @@ class ServersDataManager @Inject constructor(
updateWithMutex {
withContext(dispatcherProvider.Comp) {
UpdateResult(
statusId,
allServers.toMutableList().apply {
currentServers().statusId,
currentServers().allServers.toMutableList().apply {
removeIf { it.serverId == server.serverId }
add(server)
}
},
wallClock()
)
}
}
@@ -168,25 +185,32 @@ class ServersDataManager @Inject constructor(
updateBlock: suspend () -> UpdateResult?,
) {
updateMutex.withLock {
val updateResult: Pair<List<Server>, ServerLists>? = withContext(dispatcherProvider.Comp) {
val updateResult: Triple<List<Server>, ServerLists, Long?>? = withContext(dispatcherProvider.Comp) {
val update = updateBlock() ?: return@withContext null
val newServers = update.servers.filter { it.isVisible }
val groupedServers = async { updateServerLists(newServers, update.statusId) }
val sortedServers = async { newServers.sortedBy { it.score } }
val newServerLists = groupedServers.await()
.copy(allServersByScore = sortedServers.await())
update.servers to newServerLists
Triple(update.servers,newServerLists, update.serverListUpdateTimestamp)
}
if (updateResult != null) {
val (allServers, newServerLists) = updateResult
if (saveToStorage) {
serversStore.save(allServers, newServerLists.statusId)
val (allServers, newServerLists, updateTimestamp) = updateResult
if (updateTimestamp != null) {
lastUpdateTimestamp = updateTimestamp
}
serverLists = newServerLists
if (saveToStorage) {
serversStore.save(allServers, newServerLists.statusId, lastUpdateTimestamp)
}
serverListsFlow.value = newServerLists
}
}
}
// Use only when protected by the updateMutex.
private fun currentServers() = serverListsFlow.value ?: ServerLists.Empty
companion object {
private fun updateServerLists(newServerList: List<Server>, statusId: LogicalsStatusId?): ServerLists {
fun MutableMap<String, MutableList<Server>>.addServer(

View File

@@ -38,11 +38,14 @@ class ServersStore(
private set
var allServers: List<Server> = emptyList()
private set
var lastUpdateTimestamp: Long = 0
private set
suspend fun load(): Boolean {
val data = store.read()
if (data != null) {
serversStatusId = data.statusFileId
lastUpdateTimestamp = data.updateTimestamp
allServers = if (data.allServers.isEmpty() && data.vpnCountries.isNotEmpty()) {
extractServers(data.vpnCountries, data.secureCoreEntryCountries, data.secureCoreExitCountries)
} else {
@@ -52,10 +55,11 @@ class ServersStore(
return data != null
}
fun save(newServers: List<Server>, newStatusId: LogicalsStatusId?) {
fun save(newServers: List<Server>, newStatusId: LogicalsStatusId?, updateTimestamp: Long) {
serversStatusId = newStatusId
allServers = newServers
val data = ServersSerializationData(allServers, serversStatusId)
lastUpdateTimestamp = updateTimestamp
val data = ServersSerializationData(allServers, serversStatusId, updateTimestamp)
store.store(data)
}
@@ -94,6 +98,7 @@ class ServersStore(
class ServersSerializationData(
val allServers: List<Server> = emptyList(),
val statusFileId: String? = null,
val updateTimestamp: Long = 0,
// Deprecated, used only for migration.
val vpnCountries: List<VpnCountry> = emptyList(),

View File

@@ -48,8 +48,8 @@ import javax.inject.Inject
class UpdateServerListFromApi @Inject constructor(
private val api: ProtonApiRetroFit,
private val dispatcherProvider: DispatcherProvider,
@WallClock private val wallClock: () -> Long,
private val serverManager: ServerManager,
private val serversDataManager: ServersDataManager,
private val prefs: ServerListUpdaterPrefs,
private val updateWithBinaryStatus: UpdateServersWithBinaryStatus,
private val binaryServerStatusEnabled: IsBinaryServerStatusEnabled,
@@ -137,7 +137,7 @@ class UpdateServerListFromApi @Inject constructor(
is FetchResult.NewServers,
is FetchResult.NotModified -> {
serverManager.updateTimestamp()
serversDataManager.updateLastUpdateTimestamp()
PeriodicActionResult(Result.Success, isSuccess = true)
}
}

View File

@@ -39,6 +39,7 @@ import com.protonvpn.android.logging.LogCategory
import com.protonvpn.android.logging.ProtonLogger
import com.protonvpn.android.models.vpn.UserLocation
import com.protonvpn.android.servers.IsBinaryServerStatusEnabled
import com.protonvpn.android.servers.ServersDataManager
import com.protonvpn.android.servers.UpdateServerListFromApi
import com.protonvpn.android.servers.api.ServersCountResponse
import com.protonvpn.android.utils.ServerManager
@@ -86,6 +87,7 @@ class ServerListUpdater @Inject constructor(
private val scope: CoroutineScope,
private val api: ProtonApiRetroFit,
private val serverManager: ServerManager,
private val serversDataManager: ServersDataManager,
private val currentUser: CurrentUser,
private val vpnStateMonitor: VpnStateMonitor,
userPlanManager: UserPlanManager,
@@ -118,7 +120,7 @@ class ServerListUpdater @Inject constructor(
)
suspend fun needsUpdate() = serverManager.needsUpdate() ||
wallClock() - serverManager.lastUpdateTimestamp >= 4 * remoteConfig.first().foregroundDelayMs
wallClock() - serversDataManager.lastUpdateTimestamp >= 4 * remoteConfig.first().foregroundDelayMs
init {
migrateIpAddress()

View File

@@ -23,7 +23,6 @@ import com.protonvpn.android.BuildConfig
import com.protonvpn.android.appconfig.UserCountryPhysical
import com.protonvpn.android.auth.data.VpnUser
import com.protonvpn.android.auth.data.hasAccessToServer
import com.protonvpn.android.di.WallClock
import com.protonvpn.android.excludedlocations.ExcludedLocations
import com.protonvpn.android.logging.ProtonLogger
import com.protonvpn.android.models.profiles.Profile
@@ -40,6 +39,7 @@ import com.protonvpn.android.redesign.vpn.isExcluded
import com.protonvpn.android.redesign.vpn.satisfiesFeatures
import com.protonvpn.android.servers.Server
import com.protonvpn.android.servers.ServersDataManager
import com.protonvpn.android.servers.ServersDataManager.ServerLists
import com.protonvpn.android.servers.ServersResult
import com.protonvpn.android.servers.api.ConnectingDomain
import com.protonvpn.android.servers.api.LoadUpdate
@@ -47,9 +47,13 @@ import com.protonvpn.android.servers.api.LogicalsStatusId
import com.protonvpn.android.vpn.ProtocolSelection
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.launch
import kotlinx.serialization.Serializable
import javax.inject.Inject
@@ -58,6 +62,7 @@ import javax.inject.Singleton
@Serializable
data class ServerManagerState(
val serverListAppVersionCode: Int = 0,
@Deprecated("The timestamp is now stored by ServersDataManager.")
val lastUpdateTimestamp: Long = 0L,
/** Can be checked even before servers are loaded from storage */
@@ -68,9 +73,8 @@ data class ServerManagerState(
@Deprecated("Use ServerManager2 in new code")
@Singleton
class ServerManager @Inject constructor(
private val mainScope: CoroutineScope,
@WallClock private val wallClock: () -> Long,
val serversData: ServersDataManager,
mainScope: CoroutineScope,
val serversDataManager: ServersDataManager,
val physicalUserCountry: UserCountryPhysical,
) {
private var savedState: ServerManagerState
@@ -82,23 +86,25 @@ class ServerManager @Inject constructor(
// combine to react to updates.
val serverListVersion = MutableStateFlow(0)
val lastUpdateTimestamp get() = savedState.lastUpdateTimestamp
val isDownloadedAtLeastOnce get() = serversDataManager.lastUpdateTimestamp > 0
/** Can be checked even before servers are loaded from storage */
val isDownloadedAtLeastOnce get() = savedState.lastUpdateTimestamp > 0
val isDownloadedAtLeastOnceFlow = serverListVersion.map { isDownloadedAtLeastOnce }.distinctUntilChanged()
val hasCountriesFlow = serverListVersion.map { savedState.hasCountries }.distinctUntilChanged()
val hasGatewaysFlow = serverListVersion.map { savedState.hasGateways }.distinctUntilChanged()
// The cached values are to be used only by legacy code.
private val serversDataCachedFlow =
serversDataManager.serverLists.stateIn(mainScope, SharingStarted.Eagerly, ServerLists.Empty)
private val serversDataCached get() = serversDataCachedFlow.value // May be empty.
suspend fun needsUpdate(): Boolean {
ensureLoaded()
return savedState.lastUpdateTimestamp == 0L || serversData.allServers.isEmpty() ||
val serversData = serversDataManager.serverLists.first()
return serversDataManager.lastUpdateTimestamp == 0L || serversData.allServers.isEmpty() ||
!haveWireGuardSupport() || savedState.serverListAppVersionCode < BuildConfig.VERSION_CODE
}
val logicalsStatusId get() = serversData.statusId
val allServers get() = serversData.allServers
val allServersByScore get() = serversData.allServersByScore
val logicalsStatusId get() = serversDataCached.statusId
val allServers get() = serversDataCached.allServers
val allServersByScore get() = serversDataCached.allServersByScore
val freeCountries
get() = getVpnCountries()
@@ -114,7 +120,7 @@ class ServerManager @Inject constructor(
}
mainScope.launch {
val loaded = serversData.load()
val loaded = serversDataManager.load()
if (!loaded) {
// We had servers saved but failed to load them, reset state.
updateAndSave(
@@ -124,12 +130,26 @@ class ServerManager @Inject constructor(
hasCountries = false,
)
)
} else if (
serversDataManager.lastUpdateTimestamp == 0L &&
savedState.lastUpdateTimestamp > 0
) {
serversDataManager.updateLastUpdateTimestamp(savedState.lastUpdateTimestamp)
updateAndSave(savedState.copy(lastUpdateTimestamp = 0L))
}
// Notify of loaded state and update after everything has been updated.
isLoaded.value = true
onServersUpdate()
}
serversDataManager.serverLists.onEach { serversData ->
updateAndSave(
savedState.copy(
serverListAppVersionCode = BuildConfig.VERSION_CODE,
hasGateways = serversData.gateways.isNotEmpty(),
hasCountries = serversData.vpnCountries.isNotEmpty(),
)
)
serverListVersion.value++
}.launchIn(mainScope)
}
suspend fun ensureLoaded() {
@@ -137,18 +157,14 @@ class ServerManager @Inject constructor(
}
fun getExitCountries(secureCore: Boolean) = if (secureCore)
serversData.secureCoreExitCountries else serversData.vpnCountries
private fun onServersUpdate() {
++serverListVersion.value
}
serversDataCached.secureCoreExitCountries else serversDataCached.vpnCountries
override fun toString(): String {
val lastUpdateTimestampLog = savedState.lastUpdateTimestamp
val lastUpdateTimestampLog = serversDataManager.lastUpdateTimestamp
.takeIf { it != 0L }
?.let { ProtonLogger.formatTime(it) }
return "vpnCountries: ${serversData.vpnCountries.size} gateways: ${serversData.gateways.size}" +
" exit: ${serversData.secureCoreExitCountries.size} " +
return "vpnCountries: ${serversDataCached.vpnCountries.size} gateways: ${serversDataCached.gateways.size}" +
" exit: ${serversDataCached.secureCoreExitCountries.size} " +
"ServerManager Updated: $lastUpdateTimestampLog"
}
@@ -190,55 +206,35 @@ class ServerManager @Inject constructor(
retainIDs: Set<String> = emptySet()
) {
ensureLoaded()
serversData.replaceServers(serverList, statusId, retainIDs)
updateAndSave(
savedState.copy(
lastUpdateTimestamp = wallClock(),
serverListAppVersionCode = BuildConfig.VERSION_CODE,
hasGateways = serversData.gateways.isNotEmpty(),
hasCountries = serversData.vpnCountries.isNotEmpty(),
)
)
onServersUpdate()
}
fun updateTimestamp() {
updateAndSave(
savedState.copy(lastUpdateTimestamp = wallClock())
)
serversDataManager.replaceServers(serverList, statusId, retainIDs)
}
suspend fun updateServerDomainStatus(connectingDomain: ConnectingDomain) {
ensureLoaded()
serversData.updateServerDomainStatus(connectingDomain)
onServersUpdate()
serversDataManager.updateServerDomainStatus(connectingDomain)
}
suspend fun updateLoads(loadsList: List<LoadUpdate>) {
ensureLoaded()
serversData.updateLoads(loadsList)
onServersUpdate()
serversDataManager.updateLoads(loadsList)
}
suspend fun updateBinaryLoads(statusId: LogicalsStatusId, loads: ByteArray) {
ensureLoaded()
serversData.updateBinaryLoads(statusId, loads)
onServersUpdate()
serversDataManager.updateBinaryLoads(statusId, loads)
}
suspend fun updateOrAddServer(server: Server) {
ensureLoaded()
serversData.updateOrAddServer(server)
onServersUpdate()
serversDataManager.updateOrAddServer(server)
}
fun getServerById(id: String) =
allServers.firstOrNull { it.serverId == id } ?: guestHoleServers?.firstOrNull { it.serverId == id }
fun getVpnCountries(): List<VpnCountry> = serversData.vpnCountries.sortedByLocaleAware { it.countryName }
fun getVpnCountries(): List<VpnCountry> = serversDataCached.vpnCountries.sortedByLocaleAware { it.countryName }
fun getGateways(): List<GatewayGroup> = serversData.gateways
fun getGateways(): List<GatewayGroup> = serversDataCached.gateways
@Deprecated("Use the suspending getVpnExitCountry from ServerManager2")
fun getVpnExitCountry(countryCode: String, secureCoreCountry: Boolean): VpnCountry? =
@@ -283,7 +279,7 @@ class ServerManager @Inject constructor(
}
fun getSecureCoreExitCountries(): List<VpnCountry> =
serversData.secureCoreExitCountries.sortedByLocaleAware { it.countryName }
serversDataCached.secureCoreExitCountries.sortedByLocaleAware { it.countryName }
@Deprecated("Use getServerForConnectIntent")
fun getServerForProfile(
@@ -475,7 +471,7 @@ class ServerManager @Inject constructor(
}
private fun haveWireGuardSupport() =
serversData.allServers.any { server -> server.connectingDomains.any { it.publicKeyX25519 != null } }
serversDataCached.allServers.any { server -> server.connectingDomains.any { it.publicKeyX25519 != null } }
private fun <T> Server?.handleServersResult(
onServersResult: (ServersResult) -> T,

View File

@@ -35,6 +35,7 @@ import kotlinx.coroutines.test.StandardTestDispatcher
import kotlinx.coroutines.test.TestDispatcher
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.currentTime
import kotlinx.coroutines.test.runCurrent
import kotlinx.coroutines.test.runTest
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
@@ -62,6 +63,7 @@ class ServerManagerStorageTests {
testDispatcherProvider,
createInMemoryServersStore(emptyList(), ""),
FakeUpdateServersWithBinaryStatus(),
testScope::currentTime,
)
}
@@ -75,14 +77,14 @@ class ServerManagerStorageTests {
}
val serverManager = ServerManager(
mainScope = testScope.backgroundScope,
wallClock = testScope::currentTime,
serversData = serversDataManager,
createNoopUserCountry(),
serversDataManager = serversDataManager,
physicalUserCountry = createNoopUserCountry(),
)
assertTrue(serverManager.isDownloadedAtLeastOnce)
assertEquals(1769690569069L, serverManager.lastUpdateTimestamp)
assertTrue(serverManager.hasCountriesFlow.first())
assertFalse(serverManager.hasGatewaysFlow.first())
serverManager.ensureLoaded()
assertTrue(serverManager.isDownloadedAtLeastOnce)
assertEquals(1769690569069L, serversDataManager.lastUpdateTimestamp)
}
}

View File

@@ -30,6 +30,7 @@ import com.protonvpn.android.models.vpn.data.LogicalsMetadata
import com.protonvpn.android.servers.FakeIsBinaryServerStatusFeatureFlagEnabled
import com.protonvpn.android.servers.IsBinaryServerStatusEnabled
import com.protonvpn.android.servers.Server
import com.protonvpn.android.servers.ServersDataManager
import com.protonvpn.android.servers.UpdateServerListFromApi
import com.protonvpn.android.servers.api.LogicalServer
import com.protonvpn.android.servers.api.LogicalServerV1
@@ -49,6 +50,7 @@ import com.protonvpn.android.vpn.usecases.FakeServerListTruncationEnabled
import com.protonvpn.android.vpn.usecases.GetTruncationMustHaveIDs
import com.protonvpn.mocks.FakeUpdateServersWithBinaryStatus
import com.protonvpn.mocks.createInMemoryServerManager
import com.protonvpn.mocks.createInMemoryServersDataManager
import com.protonvpn.test.shared.MockSharedPreference
import com.protonvpn.test.shared.MockSharedPreferencesProvider
import com.protonvpn.test.shared.MockedServers
@@ -128,6 +130,7 @@ class ServerListUpdaterTests {
private lateinit var vpnStateMonitor: VpnStateMonitor
private lateinit var mustHaveIDs: Set<String>
private lateinit var binaryStatusFfEnabled: MutableStateFlow<Boolean>
private lateinit var serversDataManager: ServersDataManager
private lateinit var serverManager: ServerManager
private lateinit var truncationEnabled: MutableStateFlow<Boolean>
private lateinit var runWhileGettingServerList: () -> Unit
@@ -186,7 +189,8 @@ class ServerListUpdaterTests {
mustHaveIDs = emptySet()
binaryStatusFfEnabled = MutableStateFlow(false)
serverManager = createInMemoryServerManager(testScope, TestDispatcherProvider(testDispatcher), initialServers = emptyList())
serversDataManager = createInMemoryServersDataManager(testScope, TestDispatcherProvider(testDispatcher))
serverManager = createInMemoryServerManager(testScope, serversDataManager)
truncationEnabled = MutableStateFlow(true)
val getNetZone = GetNetZone(serverListUpdaterPrefs)
val serverListTruncationFF = FakeServerListTruncationEnabled(truncationEnabled)
@@ -198,8 +202,8 @@ class ServerListUpdaterTests {
val updateServerListFromApi = UpdateServerListFromApi(
mockApi,
TestDispatcherProvider(testDispatcher),
testScope::currentTime,
serverManager,
serversDataManager,
serverListUpdaterPrefs,
fakeUpdateWithBinaryStatus,
binaryServerStatusEnabled,
@@ -210,6 +214,7 @@ class ServerListUpdaterTests {
scope = testScope.backgroundScope,
api = mockApi,
serverManager = serverManager,
serversDataManager = serversDataManager,
currentUser = mockCurrentUser,
vpnStateMonitor = vpnStateMonitor,
userPlanManager = mockPlanManager,
@@ -342,7 +347,7 @@ class ServerListUpdaterTests {
val result1 = serverListUpdater.updateServers()
assertEquals(successResult, result1.result)
assertEquals(listOf("id1"), serverManager.allServers.map { it.serverId })
assertEquals(firstUpdateTimestamp, serverManager.lastUpdateTimestamp)
assertEquals(firstUpdateTimestamp, serversDataManager.lastUpdateTimestamp)
// Version will not change for the next call
advanceTimeBy(5.minutes)
@@ -352,7 +357,7 @@ class ServerListUpdaterTests {
// 304 does not result in a call to setServers but will refresh timestamp.
assertEquals(listOf("id1"), serverManager.allServers.map { it.serverId })
assertEquals(firstUpdateTimestamp, serverListUpdaterPrefs.serverListLastModified)
assertEquals(currentTime, serverManager.lastUpdateTimestamp)
assertEquals(currentTime, serversDataManager.lastUpdateTimestamp)
// Make new version available
fakeServerListV2Backend.serverLastModified = { currentTime }
@@ -360,7 +365,7 @@ class ServerListUpdaterTests {
val result3 = serverListUpdater.updateServers()
assertEquals(successResult, result3.result)
assertEquals(currentTime, serverListUpdaterPrefs.serverListLastModified)
assertEquals(currentTime, serverManager.lastUpdateTimestamp)
assertEquals(currentTime, serversDataManager.lastUpdateTimestamp)
assertEquals(listOf("id2"), serverManager.allServers.map { it.serverId })
}

View File

@@ -31,9 +31,11 @@ import com.protonvpn.test.shared.createServer
import junit.framework.TestCase.assertEquals
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.test.TestDispatcher
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import kotlinx.coroutines.test.currentTime
import kotlinx.coroutines.test.resetMain
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.test.setMain
@@ -62,6 +64,7 @@ class ServersDataManagerTests {
TestDispatcherProvider(testDispatcher),
store,
fakeServerStateUpdater,
testScope::currentTime,
)
}
@@ -89,7 +92,7 @@ class ServersDataManagerTests {
null,
retainIDs = setOf("1", "2")
)
assertEquals(setOf("1", "2", "4"), manager.allServers.toIds())
assertEquals(setOf("1", "2", "4"), manager.allServers().toIds())
}
@Test
@@ -105,18 +108,20 @@ class ServersDataManagerTests {
val statusId = "status ID"
manager.replaceServers(servers, statusId, emptySet())
assertEquals(setOf("1", "3", "5"), manager.allServers.toIds())
assertEquals(setOf("1"), manager.vpnCountries.find { it.flag == "PL" }?.serverList?.toIds())
assertEquals(setOf("3"), manager.gateways.find { it.name() == "company" }?.serverList?.toIds())
assertEquals(setOf("5"), manager.secureCoreExitCountries.find { it.flag == "PL" }?.serverList?.toIds())
val afterReplace = manager.serverLists.first()
assertEquals(setOf("1", "3", "5"), afterReplace.allServers.toIds())
assertEquals(setOf("1"), afterReplace.vpnCountries.find { it.flag == "PL" }?.serverList?.toIds())
assertEquals(setOf("3"), afterReplace.gateways.find { it.name() == "company" }?.serverList?.toIds())
assertEquals(setOf("5"), afterReplace.secureCoreExitCountries.find { it.flag == "PL" }?.serverList?.toIds())
fakeServerStateUpdater.mapsAllServers { it.copy(isVisible = true) }
manager.updateBinaryLoads(statusId, ByteArray(0))
assertEquals(setOf("1", "2", "3", "4", "5", "6"), manager.allServers.toIds())
assertEquals(setOf("1", "2"), manager.vpnCountries.find { it.flag == "PL" }?.serverList?.toIds())
assertEquals(setOf("3", "4"), manager.gateways.find { it.name() == "company" }?.serverList?.toIds())
assertEquals(setOf("5", "6"), manager.secureCoreExitCountries.find { it.flag == "PL" }?.serverList?.toIds())
val afterUpdate = manager.serverLists.first()
assertEquals(setOf("1", "2", "3", "4", "5", "6"), afterUpdate.allServers.toIds())
assertEquals(setOf("1", "2"), afterUpdate.vpnCountries.find { it.flag == "PL" }?.serverList?.toIds())
assertEquals(setOf("3", "4"), afterUpdate.gateways.find { it.name() == "company" }?.serverList?.toIds())
assertEquals(setOf("5", "6"), afterUpdate.secureCoreExitCountries.find { it.flag == "PL" }?.serverList?.toIds())
}
@Test
@@ -130,11 +135,13 @@ class ServersDataManagerTests {
manager.replaceServers(servers, statusIdCurrent, emptySet())
manager.updateBinaryLoads(statusIdOld, ByteArray(0))
assertEquals(listOf(50f, 50f), manager.allServers.map { it.load })
assertEquals(listOf(50f, 50f), manager.allServers().map { it.load })
manager.updateBinaryLoads(statusIdCurrent, ByteArray(0))
assertEquals(listOf(newLoad, newLoad), manager.allServers.map { it.load })
assertEquals(listOf(newLoad, newLoad), manager.allServers().map { it.load })
}
private suspend fun ServersDataManager.allServers() = serverLists.first().allServers
private fun Iterable<Server>.toIds(): Set<String> = this.map { it.serverId }.toSet()
}

View File

@@ -50,6 +50,7 @@ class ServersStoreTests {
private val tmpFile = File("servers${FileObjectStore.TMP_SUFFIX}")
private val serversStatusId = "statusId"
private val servers = MockedServers.serverList
private val timestamp = 1234L
private suspend fun createAndLoadServersStore(testFile: File = File("servers")) =
ServersStore.create(
@@ -80,31 +81,34 @@ class ServersStoreTests {
fun `basic store, read and clear`() = testScope.runTest {
val store = createAndLoadServersStore(testFile)
store.save(servers, serversStatusId)
store.save(servers, serversStatusId, timestamp)
assertTrue(testFile.exists())
val store2 = createAndLoadServersStore(testFile)
assertEquals(servers, store2.allServers)
assertEquals(serversStatusId, store2.serversStatusId)
assertEquals(timestamp, store2.lastUpdateTimestamp)
store.clear()
val store3 = createAndLoadServersStore(testFile)
assertEquals(emptyList(), store3.allServers)
assertEquals(null, store3.serversStatusId)
assertEquals(0, store3.lastUpdateTimestamp)
assertFalse(testFile.exists())
}
@Test
fun `recover from interrupted rename`() = testScope.runTest {
val store = createAndLoadServersStore(testFile)
store.save(servers, serversStatusId)
store.save(servers, serversStatusId, timestamp)
// Tmp write was successful but the rename failed
testFile.renameTo(tmpFile)
val store2 = createAndLoadServersStore(testFile)
assertEquals(servers, store2.allServers)
assertEquals(serversStatusId, store.serversStatusId)
assertEquals(serversStatusId, store2.serversStatusId)
assertEquals(timestamp, store2.lastUpdateTimestamp)
assertTrue(testFile.exists())
assertFalse(tmpFile.exists())
}
@@ -112,7 +116,7 @@ class ServersStoreTests {
@Test
fun `recover from unfinished write`() = testScope.runTest {
val store = createAndLoadServersStore(testFile)
store.save(servers, serversStatusId)
store.save(servers, serversStatusId, timestamp)
// Tmp write was interrupted but test file exists
tmpFile.writeText("corrupted")
@@ -120,6 +124,7 @@ class ServersStoreTests {
val store2 = createAndLoadServersStore(testFile)
assertEquals(servers, store2.allServers)
assertEquals(serversStatusId, store2.serversStatusId)
assertEquals(timestamp, store2.lastUpdateTimestamp)
assertTrue(testFile.exists())
assertFalse(tmpFile.exists())
}
@@ -127,7 +132,7 @@ class ServersStoreTests {
@Test
fun `recover from unfinished first save`() = testScope.runTest {
val store = createAndLoadServersStore(testFile)
store.save(servers, serversStatusId)
store.save(servers, serversStatusId, timestamp)
testFile.delete()
tmpFile.writeText("corrupted")
@@ -135,5 +140,6 @@ class ServersStoreTests {
val store2 = createAndLoadServersStore(testFile)
assertEquals(emptyList(), store2.allServers)
assertEquals(null, store2.serversStatusId)
assertEquals(0, store2.lastUpdateTimestamp)
}
}

View File

@@ -38,6 +38,22 @@ import kotlinx.coroutines.test.currentTime
import kotlinx.coroutines.test.runCurrent
@OptIn(ExperimentalCoroutinesApi::class)
fun createInMemoryServersDataManager(
testScope: TestScope,
testDispatcherProvider: TestDispatcherProvider,
initialServers: List<Server> = emptyList(),
initialStatusId: LogicalsStatusId? = null,
updateWithBinaryStatus: UpdateServersWithBinaryStatus = FakeUpdateServersWithBinaryStatus(),
): ServersDataManager {
val serverStore = createInMemoryServersStore(initialServers, initialStatusId)
return ServersDataManager(
testDispatcherProvider,
serverStore,
updateWithBinaryStatus,
testScope::currentTime
)
}
fun createInMemoryServerManager(
testScope: TestScope,
testDispatcherProvider: TestDispatcherProvider,
@@ -45,22 +61,32 @@ fun createInMemoryServerManager(
initialStatusId: LogicalsStatusId? = null,
updateWithBinaryStatus: UpdateServersWithBinaryStatus = FakeUpdateServersWithBinaryStatus(),
physicalUserCountry: UserCountryPhysical = createNoopUserCountry(),
) =
createInMemoryServerManager(
testScope = testScope,
serversDataManager =
createInMemoryServersDataManager(
testScope = testScope,
testDispatcherProvider = testDispatcherProvider,
initialServers = initialServers,
initialStatusId = initialStatusId,
updateWithBinaryStatus = updateWithBinaryStatus,
),
physicalUserCountry = physicalUserCountry,
)
@OptIn(ExperimentalCoroutinesApi::class)
fun createInMemoryServerManager(
testScope: TestScope,
serversDataManager: ServersDataManager,
physicalUserCountry: UserCountryPhysical = createNoopUserCountry(),
): ServerManager {
val serverStore = createInMemoryServersStore(initialServers, initialStatusId)
val serversDataManager = ServersDataManager(
testDispatcherProvider,
serverStore,
updateWithBinaryStatus,
)
val serverManager = ServerManager(
testScope.backgroundScope,
testScope::currentTime,
serversDataManager,
physicalUserCountry,
)
testScope.launch {
serverManager.setServers(initialServers, initialStatusId)
}
val serverManager =
ServerManager(
testScope.backgroundScope,
serversDataManager,
physicalUserCountry,
)
testScope.runCurrent()
return serverManager
}