@@ -111,6 +111,12 @@ def test_set_non_numpy_tensor(self):
111
111
self .assertEqual ([2 , 3 , 4 , 5 ], result ["values" ])
112
112
self .assertEqual ([2 , 2 ], result ["shape" ])
113
113
114
+ con .tensorset ("x" , (1 , 1 , 0 , 0 ), dtype = "bool" , shape = (2 , 2 ))
115
+ result = con .tensorget ("x" , as_numpy = False )
116
+ self .assertEqual ([True , True , False , False ], result ["values" ])
117
+ self .assertEqual ([2 , 2 ], result ["shape" ])
118
+ self .assertEqual ("BOOL" , result ["dtype" ])
119
+
114
120
with self .assertRaises (TypeError ):
115
121
con .tensorset ("x" , (2 , 3 , 4 , 5 ), dtype = "wrongtype" , shape = (2 , 2 ))
116
122
con .tensorset ("x" , (2 , 3 , 4 , 5 ), dtype = "int8" , shape = (2 , 2 ))
@@ -144,6 +150,12 @@ def test_numpy_tensor(self):
144
150
values = con .tensorget ("x" )
145
151
self .assertEqual (values .dtype , np .float64 )
146
152
153
+ input_array = np .array ([True , False ])
154
+ con .tensorset ("x" , input_array )
155
+ values = con .tensorget ("x" )
156
+ self .assertEqual (values .dtype , "bool" )
157
+ self .assertTrue (np .array_equal (values , [True , False ]))
158
+
147
159
input_array = np .array ([2 , 3 ])
148
160
con .tensorset ("x" , input_array )
149
161
values = con .tensorget ("x" )
0 commit comments