|  | 
| 13 | 13 | import logging | 
| 14 | 14 | import importlib | 
| 15 | 15 | import itertools | 
|  | 16 | +from types import ModuleType | 
| 16 | 17 | from typing import Any, Dict, List, Tuple, Union, Optional | 
| 17 | 18 | 
 | 
| 18 | 19 | from .utils import text_, bytes_ | 
| @@ -75,31 +76,54 @@ def load( | 
| 75 | 76 |             # this plugin_ is implementing | 
| 76 | 77 |             base_klass = None | 
| 77 | 78 |             for k in mro: | 
| 78 |  | -                if bytes_(k.__name__) in p: | 
|  | 79 | +                if bytes_(k.__qualname__) in p: | 
| 79 | 80 |                     base_klass = k | 
| 80 | 81 |                     break | 
| 81 | 82 |             if base_klass is None: | 
| 82 | 83 |                 raise ValueError('%s is NOT a valid plugin' % text_(plugin_)) | 
| 83 |  | -            if klass not in p[bytes_(base_klass.__name__)]: | 
| 84 |  | -                p[bytes_(base_klass.__name__)].append(klass) | 
| 85 |  | -            logger.info('Loaded plugin %s.%s', module_name, klass.__name__) | 
|  | 84 | +            if klass not in p[bytes_(base_klass.__qualname__)]: | 
|  | 85 | +                p[bytes_(base_klass.__qualname__)].append(klass) | 
|  | 86 | +            logger.info('Loaded plugin %s.%s', module_name, klass.__qualname__) | 
| 86 | 87 |         # print(p) | 
| 87 | 88 |         return p | 
| 88 | 89 | 
 | 
| 89 | 90 |     @staticmethod | 
| 90 | 91 |     def importer(plugin: Union[bytes, type]) -> Tuple[type, str]: | 
| 91 | 92 |         """Import and returns the plugin.""" | 
| 92 | 93 |         if isinstance(plugin, type): | 
| 93 |  | -            return (plugin, '__main__') | 
|  | 94 | +            if inspect.isclass(plugin): | 
|  | 95 | +                return (plugin, plugin.__module__ or '__main__') | 
|  | 96 | +            raise ValueError('%s is not a valid reference to a plugin class' % text_(plugin)) | 
| 94 | 97 |         plugin_ = text_(plugin.strip()) | 
| 95 | 98 |         assert plugin_ != '' | 
| 96 |  | -        module_name, klass_name = plugin_.rsplit(text_(DOT), 1) | 
| 97 |  | -        klass = getattr( | 
| 98 |  | -            importlib.import_module( | 
| 99 |  | -                module_name.replace( | 
| 100 |  | -                    os.path.sep, text_(DOT), | 
| 101 |  | -                ), | 
| 102 |  | -            ), | 
| 103 |  | -            klass_name, | 
| 104 |  | -        ) | 
|  | 99 | +        path = plugin_.split(text_(DOT)) | 
|  | 100 | +        klass = None | 
|  | 101 | + | 
|  | 102 | +        def locate_klass(klass_module_name: str, klass_path: List[str]) -> Union[type, None]: | 
|  | 103 | +            klass_module_name = klass_module_name.replace(os.path.sep, text_(DOT)) | 
|  | 104 | +            try: | 
|  | 105 | +                klass_module = importlib.import_module(klass_module_name) | 
|  | 106 | +            except ModuleNotFoundError: | 
|  | 107 | +                return None | 
|  | 108 | +            klass_container: Union[ModuleType, type] = klass_module | 
|  | 109 | +            for klass_path_part in klass_path: | 
|  | 110 | +                try: | 
|  | 111 | +                    klass_container = getattr(klass_container, klass_path_part) | 
|  | 112 | +                except AttributeError: | 
|  | 113 | +                    return None | 
|  | 114 | +            if not isinstance(klass_container, type) or not inspect.isclass(klass_container): | 
|  | 115 | +                return None | 
|  | 116 | +            return klass_container | 
|  | 117 | + | 
|  | 118 | +        module_name = None | 
|  | 119 | +        for module_name_parts in range(len(path) - 1, 0, -1): | 
|  | 120 | +            module_name = '.'.join(path[0:module_name_parts]) | 
|  | 121 | +            klass = locate_klass(module_name, path[module_name_parts:]) | 
|  | 122 | +            if klass: | 
|  | 123 | +                break | 
|  | 124 | +        if klass is None: | 
|  | 125 | +            module_name = '__main__' | 
|  | 126 | +            klass = locate_klass(module_name, path) | 
|  | 127 | +        if klass is None or module_name is None: | 
|  | 128 | +            raise ValueError('%s is not resolvable as a plugin class' % text_(plugin)) | 
| 105 | 129 |         return (klass, module_name) | 
0 commit comments