Connect Navigation and FragmentManager's state save/restore

Update FragmentNavigator to correctly save and restore
the Fragment's state when the NavController uses the
restoreState and saveState APIs.

To enable testing the save/restore separately from
NavController, this added a new restoreBackStackEntry()
API to TestNavigatorState.

Relnote: N/A
Test: new FragmentNavigatorTest tests pass
BUG: 80029773
Change-Id: I4ac261ab2195e276c93ffd63152fe2bcefa3651a
diff --git a/navigation/navigation-fragment/src/androidTest/java/androidx/navigation/fragment/FragmentNavigatorTest.kt b/navigation/navigation-fragment/src/androidTest/java/androidx/navigation/fragment/FragmentNavigatorTest.kt
index 95192e2..2498514 100644
--- a/navigation/navigation-fragment/src/androidTest/java/androidx/navigation/fragment/FragmentNavigatorTest.kt
+++ b/navigation/navigation-fragment/src/androidTest/java/androidx/navigation/fragment/FragmentNavigatorTest.kt
@@ -541,6 +541,153 @@
             .containsExactly(entry)
     }
 
+    @UiThreadTest
+    @Test
+    fun testSaveRestoreState() {
+        val entry = createBackStackEntry()
+
+        // First push an initial Fragment
+        fragmentNavigator.navigate(listOf(entry), null, null)
+        assertThat(navigatorState.backStack.value)
+            .containsExactly(entry)
+        fragmentManager.executePendingTransactions()
+        val fragment = fragmentManager.findFragmentById(R.id.container)
+        assertWithMessage("Fragment should be added")
+            .that(fragment)
+            .isNotNull()
+
+        // Now push the Fragment that we want to save
+        val replacementEntry = createBackStackEntry(SECOND_FRAGMENT, SavedStateFragment::class)
+        fragmentNavigator.navigate(listOf(replacementEntry), null, null)
+        assertThat(navigatorState.backStack.value)
+            .containsExactly(entry, replacementEntry).inOrder()
+        fragmentManager.executePendingTransactions()
+        val replacementFragment = fragmentManager.findFragmentById(R.id.container)
+        assertWithMessage("Replacement Fragment should be added")
+            .that(replacementFragment)
+            .isNotNull()
+        assertWithMessage("Replacement Fragment should be the correct type")
+            .that(replacementFragment)
+            .isInstanceOf(SavedStateFragment::class.java)
+        assertWithMessage("Replacement Fragment should be the primary navigation Fragment")
+            .that(fragmentManager.primaryNavigationFragment)
+            .isSameInstanceAs(replacementFragment)
+
+        // Save some state into the replacement fragment
+        (replacementFragment as SavedStateFragment).savedState = "test"
+
+        // Now save the Fragment
+        fragmentNavigator.popBackStack(replacementEntry, true)
+        fragmentManager.executePendingTransactions()
+        assertThat(navigatorState.backStack.value)
+            .containsExactly(entry)
+        assertWithMessage("Fragment should be the primary navigation Fragment after pop")
+            .that(fragmentManager.primaryNavigationFragment)
+            .isSameInstanceAs(fragment)
+
+        // And now restore the fragment
+        val restoredEntry = navigatorState.restoreBackStackEntry(replacementEntry)
+        fragmentNavigator.navigate(
+            listOf(restoredEntry),
+            NavOptions.Builder().setRestoreState(true).build(), null
+        )
+        assertThat(navigatorState.backStack.value)
+            .containsExactly(entry, restoredEntry).inOrder()
+        fragmentManager.executePendingTransactions()
+        val restoredFragment = fragmentManager.findFragmentById(R.id.container)
+        assertWithMessage("Restored Fragment should be added")
+            .that(restoredFragment)
+            .isNotNull()
+        assertWithMessage("Restored Fragment should be the correct type")
+            .that(restoredFragment)
+            .isInstanceOf(SavedStateFragment::class.java)
+        assertWithMessage("Restored Fragment should be the primary navigation Fragment")
+            .that(fragmentManager.primaryNavigationFragment)
+            .isSameInstanceAs(restoredFragment)
+
+        assertWithMessage("Restored Fragment should have its state restored")
+            .that((restoredFragment as SavedStateFragment).savedState)
+            .isEqualTo("test")
+    }
+
+    @UiThreadTest
+    @Test
+    fun testSaveRestoreStateAfterSaveState() {
+        val entry = createBackStackEntry()
+
+        // First push an initial Fragment
+        fragmentNavigator.navigate(listOf(entry), null, null)
+        assertThat(navigatorState.backStack.value)
+            .containsExactly(entry)
+        fragmentManager.executePendingTransactions()
+        val fragment = fragmentManager.findFragmentById(R.id.container)
+        assertWithMessage("Fragment should be added")
+            .that(fragment)
+            .isNotNull()
+
+        // Now push the Fragment that we want to save
+        val replacementEntry = createBackStackEntry(SECOND_FRAGMENT, SavedStateFragment::class)
+        fragmentNavigator.navigate(listOf(replacementEntry), null, null)
+        assertThat(navigatorState.backStack.value)
+            .containsExactly(entry, replacementEntry).inOrder()
+        fragmentManager.executePendingTransactions()
+        val replacementFragment = fragmentManager.findFragmentById(R.id.container)
+        assertWithMessage("Replacement Fragment should be added")
+            .that(replacementFragment)
+            .isNotNull()
+        assertWithMessage("Replacement Fragment should be the correct type")
+            .that(replacementFragment)
+            .isInstanceOf(SavedStateFragment::class.java)
+        assertWithMessage("Replacement Fragment should be the primary navigation Fragment")
+            .that(fragmentManager.primaryNavigationFragment)
+            .isSameInstanceAs(replacementFragment)
+
+        // Save some state into the replacement fragment
+        (replacementFragment as SavedStateFragment).savedState = "test"
+
+        // Now save the Fragment
+        fragmentNavigator.popBackStack(replacementEntry, true)
+        fragmentManager.executePendingTransactions()
+        assertThat(navigatorState.backStack.value)
+            .containsExactly(entry)
+        assertWithMessage("Fragment should be the primary navigation Fragment after pop")
+            .that(fragmentManager.primaryNavigationFragment)
+            .isSameInstanceAs(fragment)
+
+        // Create a new FragmentNavigator, replacing the previous one
+        val savedState = fragmentNavigator.onSaveState() as Bundle
+        fragmentNavigator = FragmentNavigator(
+            emptyActivity,
+            fragmentManager, R.id.container
+        )
+        fragmentNavigator.onAttach(navigatorState)
+        fragmentNavigator.onRestoreState(savedState)
+
+        // And now restore the fragment
+        val restoredEntry = navigatorState.restoreBackStackEntry(replacementEntry)
+        fragmentNavigator.navigate(
+            listOf(restoredEntry),
+            NavOptions.Builder().setRestoreState(true).build(), null
+        )
+        assertThat(navigatorState.backStack.value)
+            .containsExactly(entry, restoredEntry).inOrder()
+        fragmentManager.executePendingTransactions()
+        val restoredFragment = fragmentManager.findFragmentById(R.id.container)
+        assertWithMessage("Restored Fragment should be added")
+            .that(restoredFragment)
+            .isNotNull()
+        assertWithMessage("Restored Fragment should be the correct type")
+            .that(restoredFragment)
+            .isInstanceOf(SavedStateFragment::class.java)
+        assertWithMessage("Restored Fragment should be the primary navigation Fragment")
+            .that(fragmentManager.primaryNavigationFragment)
+            .isSameInstanceAs(restoredFragment)
+
+        assertWithMessage("Restored Fragment should have its state restored")
+            .that((restoredFragment as SavedStateFragment).savedState)
+            .isEqualTo("test")
+    }
+
     @Test
     fun testToString() {
         val destination = fragmentNavigator.createDestination().apply {
@@ -583,6 +730,20 @@
     }
 }
 
+class SavedStateFragment : Fragment() {
+    var savedState: String? = null
+
+    override fun onCreate(savedInstanceState: Bundle?) {
+        super.onCreate(savedInstanceState)
+        savedState = savedInstanceState?.getString("savedState")
+    }
+
+    override fun onSaveInstanceState(outState: Bundle) {
+        super.onSaveInstanceState(outState)
+        outState.putString("savedState", savedState)
+    }
+}
+
 class NonEmptyConstructorFragment(val test: String) : Fragment()
 
 class NonEmptyFragmentFactory : FragmentFactory() {
diff --git a/navigation/navigation-fragment/src/main/java/androidx/navigation/fragment/FragmentNavigator.kt b/navigation/navigation-fragment/src/main/java/androidx/navigation/fragment/FragmentNavigator.kt
index 35b6c00..7e575fc 100644
--- a/navigation/navigation-fragment/src/main/java/androidx/navigation/fragment/FragmentNavigator.kt
+++ b/navigation/navigation-fragment/src/main/java/androidx/navigation/fragment/FragmentNavigator.kt
@@ -23,6 +23,7 @@
 import androidx.annotation.CallSuper
 import androidx.annotation.IdRes
 import androidx.core.content.res.use
+import androidx.core.os.bundleOf
 import androidx.fragment.app.Fragment
 import androidx.fragment.app.FragmentManager
 import androidx.navigation.NavBackStackEntry
@@ -51,6 +52,7 @@
     private val fragmentManager: FragmentManager,
     private val containerId: Int
 ) : Navigator<Destination>() {
+    private val savedIds = mutableSetOf<String>()
 
     /**
      * {@inheritDoc}
@@ -71,10 +73,33 @@
             )
             return
         }
-        fragmentManager.popBackStack(
-            popUpTo.id,
-            FragmentManager.POP_BACK_STACK_INCLUSIVE
-        )
+        if (savedState) {
+            val beforePopList = state.backStack.value
+            val initialEntry = beforePopList.first()
+            // Get the set of entries that are going to be popped
+            val poppedList = beforePopList.subList(
+                beforePopList.indexOf(popUpTo),
+                beforePopList.size
+            )
+            // Now go through the list in reversed order (i.e., started from the most added)
+            // and save the back stack state of each.
+            for (entry in poppedList.reversed()) {
+                if (entry == initialEntry) {
+                    Log.i(
+                        TAG,
+                        "FragmentManager cannot save the state of the initial destination $entry"
+                    )
+                } else {
+                    fragmentManager.saveBackStack(entry.id)
+                    savedIds += entry.id
+                }
+            }
+        } else {
+            fragmentManager.popBackStack(
+                popUpTo.id,
+                FragmentManager.POP_BACK_STACK_INCLUSIVE
+            )
+        }
         state.pop(popUpTo, savedState)
     }
 
@@ -143,6 +168,19 @@
         navOptions: NavOptions?,
         navigatorExtras: Navigator.Extras?
     ) {
+        val backStack = state.backStack.value
+        val initialNavigation = backStack.isEmpty()
+        val restoreState = (
+            navOptions != null && !initialNavigation &&
+                navOptions.shouldRestoreState() &&
+                savedIds.remove(entry.id)
+            )
+        if (restoreState) {
+            // Restore back stack does all the work to restore the entry
+            fragmentManager.restoreBackStack(entry.id)
+            state.add(entry)
+            return
+        }
         val destination = entry.destination as Destination
         val args = entry.arguments
         var className = destination.className
@@ -166,8 +204,6 @@
         ft.replace(containerId, frag)
         ft.setPrimaryNavigationFragment(frag)
         @IdRes val destId = destination.id
-        val backStack = state.backStack.value
-        val initialNavigation = backStack.isEmpty()
         // TODO Build first class singleTop behavior for fragments
         val isSingleTopReplacement = (
             navOptions != null && !initialNavigation &&
@@ -211,6 +247,21 @@
         }
     }
 
+    public override fun onSaveState(): Bundle? {
+        if (savedIds.isEmpty()) {
+            return null
+        }
+        return bundleOf(KEY_SAVED_IDS to ArrayList(savedIds))
+    }
+
+    public override fun onRestoreState(savedState: Bundle) {
+        val savedIds = savedState.getStringArrayList(KEY_SAVED_IDS)
+        if (savedIds != null) {
+            this.savedIds.clear()
+            this.savedIds += savedIds
+        }
+    }
+
     /**
      * NavDestination specific to [FragmentNavigator]
      */
@@ -351,5 +402,6 @@
 
     private companion object {
         private const val TAG = "FragmentNavigator"
+        private const val KEY_SAVED_IDS = "androidx-nav-fragment:navigator:savedIds"
     }
 }
diff --git a/navigation/navigation-testing/api/current.txt b/navigation/navigation-testing/api/current.txt
index b862bb3..8fc24e8 100644
--- a/navigation/navigation-testing/api/current.txt
+++ b/navigation/navigation-testing/api/current.txt
@@ -15,6 +15,7 @@
     ctor public TestNavigatorState(optional android.content.Context? context, optional kotlinx.coroutines.CoroutineDiser coroutineDiser);
     ctor public TestNavigatorState(optional android.content.Context? context);
     method public androidx.navigation.NavBackStackEntry createBackStackEntry(androidx.navigation.NavDestination destination, android.os.Bundle? arguments);
+    method public androidx.navigation.NavBackStackEntry restoreBackStackEntry(androidx.navigation.NavBackStackEntry previouslySavedEntry);
   }
 
 }
diff --git a/navigation/navigation-testing/api/public_plus_experimental_current.txt b/navigation/navigation-testing/api/public_plus_experimental_current.txt
index b862bb3..8fc24e8 100644
--- a/navigation/navigation-testing/api/public_plus_experimental_current.txt
+++ b/navigation/navigation-testing/api/public_plus_experimental_current.txt
@@ -15,6 +15,7 @@
     ctor public TestNavigatorState(optional android.content.Context? context, optional kotlinx.coroutines.CoroutineDiser coroutineDiser);
     ctor public TestNavigatorState(optional android.content.Context? context);
     method public androidx.navigation.NavBackStackEntry createBackStackEntry(androidx.navigation.NavDestination destination, android.os.Bundle? arguments);
+    method public androidx.navigation.NavBackStackEntry restoreBackStackEntry(androidx.navigation.NavBackStackEntry previouslySavedEntry);
   }
 
 }
diff --git a/navigation/navigation-testing/api/restricted_current.txt b/navigation/navigation-testing/api/restricted_current.txt
index b862bb3..8fc24e8 100644
--- a/navigation/navigation-testing/api/restricted_current.txt
+++ b/navigation/navigation-testing/api/restricted_current.txt
@@ -15,6 +15,7 @@
     ctor public TestNavigatorState(optional android.content.Context? context, optional kotlinx.coroutines.CoroutineDiser coroutineDiser);
     ctor public TestNavigatorState(optional android.content.Context? context);
     method public androidx.navigation.NavBackStackEntry createBackStackEntry(androidx.navigation.NavDestination destination, android.os.Bundle? arguments);
+    method public androidx.navigation.NavBackStackEntry restoreBackStackEntry(androidx.navigation.NavBackStackEntry previouslySavedEntry);
   }
 
 }
diff --git a/navigation/navigation-testing/src/main/java/androidx/navigation/testing/TestNavigatorState.kt b/navigation/navigation-testing/src/main/java/androidx/navigation/testing/TestNavigatorState.kt
index a3de911..9140e76 100644
--- a/navigation/navigation-testing/src/main/java/androidx/navigation/testing/TestNavigatorState.kt
+++ b/navigation/navigation-testing/src/main/java/androidx/navigation/testing/TestNavigatorState.kt
@@ -64,6 +64,8 @@
         }
     }
 
+    private val savedStates = mutableMapOf<String, Bundle>()
+
     override fun createBackStackEntry(
         destination: NavDestination,
         arguments: Bundle?
@@ -71,6 +73,23 @@
         context, destination, arguments, lifecycleOwner, viewModelStoreProvider
     )
 
+    /**
+     * Restore a previously saved [NavBackStackEntry]. You must have previously called
+     * [pop] with [previouslySavedEntry] and `true`.
+     */
+    public fun restoreBackStackEntry(previouslySavedEntry: NavBackStackEntry): NavBackStackEntry {
+        val savedState = checkNotNull(savedStates[previouslySavedEntry.id]) {
+            "restoreBackStackEntry(previouslySavedEntry) must be passed a NavBackStackEntry " +
+                "that was previously popped with popBackStack(previouslySavedEntry, true)"
+        }
+        return NavBackStackEntry.create(
+            context,
+            previouslySavedEntry.destination, previouslySavedEntry.arguments,
+            lifecycleOwner, viewModelStoreProvider,
+            previouslySavedEntry.id, savedState
+        )
+    }
+
     override fun add(backStackEntry: NavBackStackEntry) {
         super.add(backStackEntry)
         updateMaxLifecycle()
@@ -80,10 +99,13 @@
         val beforePopList = backStack.value
         val poppedList = beforePopList.subList(beforePopList.indexOf(popUpTo), beforePopList.size)
         super.pop(popUpTo, saveState)
-        updateMaxLifecycle(poppedList)
+        updateMaxLifecycle(poppedList, saveState)
     }
 
-    private fun updateMaxLifecycle(poppedList: List<NavBackStackEntry> = emptyList()) {
+    private fun updateMaxLifecycle(
+        poppedList: List<NavBackStackEntry> = emptyList(),
+        saveState: Boolean = false
+    ) {
         runBlocking(coroutineDiser) {
             // NavBackStackEntry Lifecycles must be updated on the main thread
             // as per the contract within Lifecycle, so we explicitly swap to the main thread
@@ -91,6 +113,13 @@
             withContext(Disers.Main.immediate) {
                 // Mark all removed NavBackStackEntries as DESTROYED
                 for (entry in poppedList.reversed()) {
+                    if (saveState) {
+                        // Move the NavBackStackEntry to the stopped state, then save its state
+                        entry.maxLifecycle = Lifecycle.State.CREATED
+                        val savedState = Bundle()
+                        entry.saveState(savedState)
+                        savedStates[entry.id] = savedState
+                    }
                     entry.maxLifecycle = Lifecycle.State.DESTROYED
                 }
                 // Now go through the current list of destinations, updating their Lifecycle state