mirror of
https://github.com/3b1b/manim.git
synced 2025-07-27 03:52:20 +08:00

Since the reload logic no longer relies on any state, the relevant loop is simple enough that it feels clearest to include it in the main entry point file.
177 lines
6.3 KiB
Python
177 lines
6.3 KiB
Python
from __future__ import annotations
|
|
|
|
import builtins
|
|
import importlib
|
|
import os
|
|
import sys
|
|
import sysconfig
|
|
|
|
from manimlib.config import get_global_config
|
|
from manimlib.logger import log
|
|
|
|
Module = importlib.util.types.ModuleType
|
|
|
|
|
|
class ModuleLoader:
|
|
"""
|
|
Utility class to load a module from a file and handle its imports.
|
|
|
|
Most parts of this class are only needed for the reload functionality,
|
|
while the `get_module` method is the main entry point to import a module.
|
|
"""
|
|
|
|
@staticmethod
|
|
def get_module(file_name: str | None, is_during_reload=False) -> Module | None:
|
|
"""
|
|
Imports a module from a file and returns it.
|
|
|
|
During reload (when the user calls `reload()` in the IPython shell), we
|
|
also track the imported modules and reload them as well (they would be
|
|
cached otherwise). See the reload_manager where the reload parameter is set.
|
|
|
|
Note that `exec_module()` is called twice when reloading a module:
|
|
1. In exec_module_and_track_imports to track the imports
|
|
2. Here to actually execute the module again with the respective
|
|
imported modules reloaded.
|
|
"""
|
|
if file_name is None:
|
|
return None
|
|
|
|
module_name = file_name.replace(os.sep, ".").replace(".py", "")
|
|
spec = importlib.util.spec_from_file_location(module_name, file_name)
|
|
module = importlib.util.module_from_spec(spec)
|
|
|
|
if is_during_reload:
|
|
imported_modules = ModuleLoader._exec_module_and_track_imports(spec, module)
|
|
reloaded_modules_tracker = set()
|
|
ModuleLoader._reload_modules(imported_modules, reloaded_modules_tracker)
|
|
|
|
spec.loader.exec_module(module)
|
|
return module
|
|
|
|
@staticmethod
|
|
def _exec_module_and_track_imports(spec, module: Module) -> set[str]:
|
|
"""
|
|
Executes the given module (imports it) and returns all the modules that
|
|
are imported during its execution.
|
|
|
|
This is achieved by replacing the __import__ function with a custom one
|
|
that tracks the imported modules. At the end, the original __import__
|
|
built-in function is restored.
|
|
"""
|
|
imported_modules: set[str] = set()
|
|
original_import = builtins.__import__
|
|
|
|
def tracked_import(name, globals=None, locals=None, fromlist=(), level=0):
|
|
"""
|
|
Custom __import__ function that does exactly the same as the original
|
|
one, but also tracks the imported modules by means of adding their
|
|
names to a set.
|
|
"""
|
|
result = original_import(name, globals, locals, fromlist, level)
|
|
imported_modules.add(name)
|
|
return result
|
|
|
|
builtins.__import__ = tracked_import
|
|
|
|
try:
|
|
module_name = module.__name__
|
|
log.debug('Reloading module "%s"', module_name)
|
|
|
|
spec.loader.exec_module(module)
|
|
finally:
|
|
builtins.__import__ = original_import
|
|
|
|
return imported_modules
|
|
|
|
@staticmethod
|
|
def _reload_modules(modules: set[str], reloaded_modules_tracker: set[str]):
|
|
"""
|
|
Out of the given modules, reloads the ones that were not already imported.
|
|
|
|
We skip modules that are not user-defined (see `is_user_defined_module()`).
|
|
"""
|
|
for mod in modules:
|
|
if mod in reloaded_modules_tracker:
|
|
continue
|
|
|
|
if not ModuleLoader._is_user_defined_module(mod):
|
|
continue
|
|
|
|
module = sys.modules[mod]
|
|
ModuleLoader._deep_reload(module, reloaded_modules_tracker)
|
|
|
|
reloaded_modules_tracker.add(mod)
|
|
|
|
@staticmethod
|
|
def _is_user_defined_module(mod: str) -> bool:
|
|
"""
|
|
Returns whether the given module is user-defined or not.
|
|
|
|
A module is considered user-defined if
|
|
- it is not part of the standard library
|
|
- AND it is not an external library (site-packages or dist-packages)
|
|
"""
|
|
if mod not in sys.modules:
|
|
return False
|
|
|
|
if mod in sys.builtin_module_names:
|
|
return False
|
|
|
|
module = sys.modules[mod]
|
|
module_path = getattr(module, "__file__", None)
|
|
if module_path is None:
|
|
return False
|
|
module_path = os.path.abspath(module_path)
|
|
|
|
# External libraries (site-packages or dist-packages), e.g. numpy
|
|
if "site-packages" in module_path or "dist-packages" in module_path:
|
|
return False
|
|
|
|
# Standard lib
|
|
standard_lib_path = sysconfig.get_path("stdlib")
|
|
if module_path.startswith(standard_lib_path):
|
|
return False
|
|
|
|
return True
|
|
|
|
@staticmethod
|
|
def _deep_reload(module: Module, reloaded_modules_tracker: set[str]):
|
|
"""
|
|
Recursively reloads modules imported by the given module.
|
|
|
|
Only user-defined modules are reloaded, see `is_user_defined_module()`.
|
|
"""
|
|
ignore_manimlib_modules = get_global_config()["ignore_manimlib_modules_on_reload"]
|
|
if ignore_manimlib_modules and module.__name__.startswith("manimlib"):
|
|
return
|
|
if module.__name__.startswith("manimlib.config"):
|
|
# We don't want to reload global config
|
|
return
|
|
|
|
if not hasattr(module, "__dict__"):
|
|
return
|
|
|
|
# Prevent reloading the same module multiple times
|
|
if module.__name__ in reloaded_modules_tracker:
|
|
return
|
|
reloaded_modules_tracker.add(module.__name__)
|
|
|
|
# Recurse for all imported modules
|
|
for _attr_name, attr_value in module.__dict__.items():
|
|
if isinstance(attr_value, Module):
|
|
if ModuleLoader._is_user_defined_module(attr_value.__name__):
|
|
ModuleLoader._deep_reload(attr_value, reloaded_modules_tracker)
|
|
|
|
# Also reload modules that are part of a class or function
|
|
# e.g. when importing `from custom_module import CustomClass`
|
|
elif hasattr(attr_value, "__module__"):
|
|
attr_module_name = attr_value.__module__
|
|
if ModuleLoader._is_user_defined_module(attr_module_name):
|
|
attr_module = sys.modules[attr_module_name]
|
|
ModuleLoader._deep_reload(attr_module, reloaded_modules_tracker)
|
|
|
|
# Reload
|
|
log.debug('Reloading module "%s"', module.__name__)
|
|
importlib.reload(module)
|