-
Notifications
You must be signed in to change notification settings - Fork 6
Meta Kernel Generation #96
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d59a7d7
651d6a1
311117c
cc2679b
22b977e
d9ad0bf
b094c87
197f0f9
efa9dc2
0cba7ba
f10b992
ec21b58
d602959
758b205
6fcef26
0bd7396
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| from .arch import fixArchitectureGlobal | ||
| from .codegen.code import Cpp | ||
|
|
||
| import os | ||
|
|
||
| class MetaGenerator: | ||
| def __init__(self, templateType): | ||
| self.templateType = templateType | ||
| self.generators = [] | ||
|
|
||
| def add_generator(self, template, generator, *args, **kwargs): | ||
| assert len(self.templateType) == len(template) | ||
| self.generators += [{ | ||
| 'name': kwargs["name"] if "name" in kwargs else str(len(self.generators)), | ||
| 'template': template, | ||
| 'generator': generator, | ||
| 'args': args, | ||
| 'kwargs': kwargs | ||
| }] | ||
|
|
||
| def compile_list(self, outputDir=''): | ||
| outfiles = [] | ||
| for gendata in self.generators: | ||
| outdirname = f'metagen_{gendata["name"]}' | ||
| outdir = os.path.join(outputDir, outdirname) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whoops; missed that one. Made it a function arg. |
||
|
|
||
| genout = [] | ||
| for file in ['tensor', 'init', 'kernel', 'test-kernel']: | ||
| genout += [file] | ||
| outfiles += [genout] | ||
| return outfiles | ||
|
|
||
| def generate_single(self, index, outputDir='', namespace='yateto'): | ||
| namespacepfx = 'yatetometagen' | ||
| gendata = self.generators[index] | ||
| subnamespace = f'{namespace}::{namespacepfx}_{gendata["name"]}' | ||
| outdirname = f'metagen_{gendata["name"]}' | ||
| outdir = os.path.join(outputDir, outdirname) | ||
| os.makedirs(outdir, exist_ok=True) | ||
|
|
||
| generator = gendata['generator'] | ||
| template = gendata['template'] | ||
| args = gendata['args'] | ||
| kwargs = gendata['kwargs'] | ||
|
|
||
| fixArchitectureGlobal(generator.arch()) | ||
| result = generator.generate(*args, **kwargs, namespace=subnamespace, outputDir=outdir) | ||
|
|
||
| tensors = {} | ||
| kernels = {} | ||
|
|
||
| for tensor in result['tensors']: | ||
| tensors[tensor] = (subnamespace, template) | ||
| for kernel in result['kernels']: | ||
| kernels[kernel] = (subnamespace, template) | ||
|
|
||
| return tensors, kernels | ||
|
|
||
| def generate(self, outputDir='', namespace='yateto', includes=[], declarationsTensors=[], declarationsKernels=[], precompiled=None): | ||
| tensors = {} | ||
| kernels = {} | ||
|
|
||
| for tensor in declarationsTensors: | ||
| tensors[tensor] = [] | ||
| for tensor in declarationsKernels: | ||
| kernels[tensor] = [] | ||
|
|
||
| for index in range(len(self.generators)): | ||
| if precompiled is None: | ||
| local_tensors, local_kernels = self.generate_single(index, outputDir, namespace) | ||
| else: | ||
| local_tensors, local_kernels = precompiled[index] | ||
|
|
||
| for tensor in local_tensors: | ||
| if tensor not in tensors: | ||
| tensors[tensor] = [] | ||
| tensors[tensor] += [local_tensors[tensor]] | ||
| for kernel in local_kernels: | ||
| if kernel not in kernels: | ||
| kernels[kernel] = [] | ||
| kernels[kernel] += [local_kernels[kernel]] | ||
|
|
||
| nspuppercase = namespace.upper() | ||
|
|
||
| def headerForward(name, data): | ||
| upper = name.upper() | ||
| with Cpp(os.path.join(outputDir, f'{name}.h')) as header: | ||
| with header.HeaderGuard(f'METAGEN_{nspuppercase}_{upper}_H_'): | ||
| for path in includes: | ||
| header.include(path) | ||
| for gendata in self.generators: | ||
| outdirname = f'metagen_{gendata["name"]}' | ||
| header.include(f'{outdirname}/{name}.h') | ||
| with header.Namespace(namespace): | ||
| for entry in data: | ||
| self.template(header, entry, data[entry], f'{name}') | ||
|
|
||
|
|
||
| headerForward('tensor', tensors) | ||
| headerForward('init', tensors) | ||
| headerForward('kernel', kernels) | ||
|
|
||
| def cppForward(name): | ||
| with Cpp(os.path.join(outputDir, f'{name}.cpp')) as header: | ||
| for gendata in self.generators: | ||
| outdirname = f'metagen_{gendata["name"]}' | ||
| header.include(f'{outdirname}/{name}.cpp') | ||
|
|
||
| cppForward('tensor') | ||
| cppForward('init') | ||
| cppForward('kernel') | ||
| cppForward('test-kernel') | ||
|
|
||
| def namespacing(self, header, spaces, inner): | ||
| if len(spaces) == 0: | ||
| inner() | ||
| else: | ||
| with header.Namespace(spaces[0]): | ||
| self.namespacing(header, spaces[1:], inner) | ||
|
|
||
| def template(self, header, prename, foundin, subnsp): | ||
| splitname = prename.split('::') | ||
|
|
||
| assert len(splitname) > 0 | ||
|
|
||
| def inner(): | ||
| name = splitname[-1] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How are we guaranteeing that
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That actually shouldn't happen; since even without a
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you mean about the split, then: (i.e. we just get one element which is the whole string) (can probably be confirmed in any Python console around if needed) |
||
| fullname = '::'.join(splitname[:-1] + [subnsp, splitname[-1]]) | ||
| escname = name.replace(':', '_') | ||
| internalName = f'Internal_{escname}' | ||
|
|
||
| templatetypes = ', '.join(f'{typ} Arg{i}' for i, typ in enumerate(self.templateType)) | ||
| templateargs = ', '.join(f'Arg{i}' for i, _ in enumerate(self.templateType)) | ||
|
|
||
| with header.Namespace('internal'): | ||
| header(f'template<{templatetypes}> struct {internalName} {"{"} using Type = void; {"}"};') | ||
| for gnsp, spec in foundin: | ||
| spectext = ', '.join(str(specpart) for specpart in spec) | ||
| header(f'template<> struct {internalName}<{spectext}> {"{"} using Type = ::{gnsp}::{fullname}; {"}"};') | ||
| header(f'template<{templatetypes}> using {name} = typename internal::{internalName}<{templateargs}>::Type;') | ||
|
|
||
| self.namespacing(header, splitname[:-1] + [subnsp], inner) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a review comment, but a general query: how do you choose which variable name to be camel case, and which variable to be normal like this? For example:
outdirnameis not camelcase, but the function argumentoutputDiris camelcase. Is it that all internal variables are not camel cases, and then class members, and function arguments are camel cases?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's say... It's a bit arbitrary. The
outputDirbeing camel case is more a relic since all of Yateto uses that convention quite a bit; especially in the older commits it seems.And so,
outputDiris usually in camel case within Yateto for now.Usually, per the official Python style guide, snake case is the "official" way to name things; though snake case can sometimes appear a bit hard to read IMO (it just gives you less to focus on—visually at least).
... that made me think; at some point we might wanna apply the whole usual swath of Python linters on Yateto. I've started setting up a pre-commit in #109.