-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdrivers.py
More file actions
39 lines (32 loc) · 992 Bytes
/
drivers.py
File metadata and controls
39 lines (32 loc) · 992 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import tvm
class GPUTemplate:
    """Default settings that ``GPU`` falls back to when no override is given.

    These are class attributes, so ``tvm.cuda(1)`` is evaluated once, when
    the module is imported.
    """

    target: str = "cuda"
    # NOTE(review): device id 1 is used here, whereas CPUTemplate uses id 0.
    # Presumably this selects a specific GPU on the author's machine; confirm.
    device = tvm.cuda(1)
    kind: str = "GPU"
class CPUTemplate:
    """Default settings that ``CPU`` falls back to when no override is given.

    These are class attributes, so ``tvm.cpu(0)`` is evaluated once, when
    the module is imported.
    """

    target: str = "llvm"
    device = tvm.cpu(0)
    kind: str = "CPU"
class DeviceDriver:
    """Holds a target string, a device handle, and a descriptive kind label.

    The base class only stores what it is given; subclasses decide the values.

    Args:
        kind: Label for the driver. Defaults to ``"Unknown"``.
        target: Target string. Defaults to ``"Unknown"``.
        device: Device handle, or ``None`` when none is supplied.
    """

    def __init__(self, kind="Unknown", target="Unknown", device=None):
        # Tuple assignment binds left to right, so the attributes are set in
        # the order target, device, kind.
        self.target, self.device, self.kind = target, device, kind
class CPU(DeviceDriver):
    """Device driver whose defaults come from ``CPUTemplate``.

    Each argument overrides the matching template value. An empty string for
    ``kind``/``target``, or ``None`` for ``device``, keeps the template default.
    """

    def __init__(self, kind="", target="", device=None):
        base = CPUTemplate
        super().__init__(
            kind if kind != "" else base.kind,
            target if target != "" else base.target,
            device if device is not None else base.device,
        )
class GPU(DeviceDriver):
    """Device driver whose defaults come from ``GPUTemplate``.

    Each argument overrides the matching template value. An empty string for
    ``kind``/``target``, or ``None`` for ``device``, keeps the template default.
    """

    def __init__(self, kind="", target="", device=None):
        base = GPUTemplate
        super().__init__(
            kind if kind != "" else base.kind,
            target if target != "" else base.target,
            device if device is not None else base.device,
        )