|
1 | | -import { anyPType, ContractClassPType, FunctionPType, PType, SourceLocation, typeRegistry, TypeResolver } from '@algorandfoundation/puya-ts' |
| 1 | +import { |
| 2 | + anyPType, |
| 3 | + ARC4StructType, |
| 4 | + ARC4TupleType, |
| 5 | + BoxMapPType, |
| 6 | + BoxPType, |
| 7 | + ContractClassPType, |
| 8 | + DynamicArrayType, |
| 9 | + FunctionPType, |
| 10 | + GlobalStateType, |
| 11 | + LocalStateType, |
| 12 | + PType, |
| 13 | + SourceLocation, |
| 14 | + StaticArrayType, |
| 15 | + TypeResolver, |
| 16 | + UFixedNxMType, |
| 17 | + UintNType, |
| 18 | +} from '@algorandfoundation/puya-ts' |
2 | 19 | import ts from 'typescript' |
3 | 20 | import type { TypeInfo } from '../encoders' |
4 | | -import { DeliberateAny } from '../typescript-helpers' |
| 21 | +import { instanceOfAny } from '../typescript-helpers' |
5 | 22 | import { TransformerConfig } from './index' |
6 | 23 | import { nodeFactory } from './node-factory' |
7 | 24 | import { |
@@ -61,9 +78,64 @@ export class SourceFileVisitor { |
61 | 78 | return new ClassVisitor(this.context, this.helper, node).result() |
62 | 79 | } |
63 | 80 |
|
| 81 | + // capture generic type info for variable initialising outside class and function declarations |
| 82 | + // e.g. `const x = new UintN<32>(42) |
| 83 | + if (ts.isVariableDeclaration(node) && node.initializer) { |
| 84 | + return new VariableInitializerVisitor(this.context, this.helper, node).result() |
| 85 | + } |
| 86 | + |
| 87 | + return ts.visitEachChild(node, this.visit, this.context) |
| 88 | + } |
| 89 | +} |
| 90 | + |
| 91 | +class ExpressionVisitor { |
| 92 | + constructor( |
| 93 | + private context: ts.TransformationContext, |
| 94 | + private helper: VisitorHelper, |
| 95 | + private expressionNode: ts.Expression, |
| 96 | + ) {} |
| 97 | + |
| 98 | + public result(): ts.Expression { |
| 99 | + return this.visit(this.expressionNode) as ts.Expression |
| 100 | + } |
| 101 | + |
| 102 | + private visit = (node: ts.Node): ts.Node => { |
| 103 | + if (ts.isCallExpression(node) || ts.isNewExpression(node)) { |
| 104 | + let type = this.helper.resolveType(node) |
| 105 | + |
| 106 | + // `voted = LocalState<uint64>()` is resolved to FunctionPType with returnType LocalState<uint64> |
| 107 | + if (type instanceof FunctionPType) type = type.returnType |
| 108 | + |
| 109 | + if (isGenericType(type)) { |
| 110 | + const info = getGenericTypeInfo(type) |
| 111 | + return nodeFactory.captureGenericTypeInfo(ts.visitEachChild(node, this.visit, this.context), JSON.stringify(info)) |
| 112 | + } |
| 113 | + } |
64 | 114 | return ts.visitEachChild(node, this.visit, this.context) |
65 | 115 | } |
66 | 116 | } |
| 117 | +class VariableInitializerVisitor { |
| 118 | + constructor( |
| 119 | + private context: ts.TransformationContext, |
| 120 | + private helper: VisitorHelper, |
| 121 | + private declarationNode: ts.VariableDeclaration, |
| 122 | + ) {} |
| 123 | + |
| 124 | + public result(): ts.VariableDeclaration { |
| 125 | + const initializerNode = this.declarationNode.initializer |
| 126 | + if (!initializerNode) return this.declarationNode |
| 127 | + |
| 128 | + const updatedInitializer = new ExpressionVisitor(this.context, this.helper, initializerNode).result() |
| 129 | + if (updatedInitializer === initializerNode) return this.declarationNode |
| 130 | + return factory.updateVariableDeclaration( |
| 131 | + this.declarationNode, |
| 132 | + this.declarationNode.name, |
| 133 | + this.declarationNode.exclamationToken, |
| 134 | + this.declarationNode.type, |
| 135 | + updatedInitializer, |
| 136 | + ) |
| 137 | + } |
| 138 | +} |
67 | 139 |
|
68 | 140 | class FunctionOrMethodVisitor { |
69 | 141 | constructor( |
@@ -110,17 +182,20 @@ class FunctionOrMethodVisitor { |
110 | 182 | * }) |
111 | 183 | * ``` |
112 | 184 | */ |
113 | | - if (this.isFunction && ts.isVariableDeclaration(node) && node.initializer && ts.isCallExpression(node.initializer)) { |
114 | | - const initializerNode = node.initializer |
115 | | - let type = this.helper.resolveType(initializerNode) |
| 185 | + if (this.isFunction && ts.isVariableDeclaration(node) && node.initializer) { |
| 186 | + return new VariableInitializerVisitor(this.context, this.helper, node).result() |
| 187 | + } |
116 | 188 |
|
117 | | - // `voted = LocalState<uint64>()` is resolved to FunctionPType with returnType LocalState<uint64> |
118 | | - if (type instanceof FunctionPType) type = type.returnType |
119 | | - if (typeRegistry.isGeneric(type)) { |
120 | | - const info = getGenericTypeInfo(type) |
121 | | - const updatedInitializer = nodeFactory.captureGenericTypeInfo(initializerNode, JSON.stringify(info)) |
122 | | - return factory.updateVariableDeclaration(node, node.name, node.exclamationToken, node.type, updatedInitializer) |
123 | | - } |
| 189 | + /* |
| 190 | + * capture generic type info in test functions and swap arc4 types with implementation; e.g. |
| 191 | + * ``` |
| 192 | + * it('should work', () => { |
| 193 | + * expect(() => new UintN<32>(2 ** 32)).toThrowError(`expected value <= ${2 ** 32 - 1}`) |
| 194 | + * }) |
| 195 | + * ``` |
| 196 | + */ |
| 197 | + if (this.isFunction && ts.isNewExpression(node)) { |
| 198 | + return new ExpressionVisitor(this.context, this.helper, node).result() |
124 | 199 | } |
125 | 200 | return node |
126 | 201 | } |
@@ -181,38 +256,55 @@ class ClassVisitor { |
181 | 256 | } |
182 | 257 |
|
183 | 258 | if (ts.isCallExpression(node)) { |
184 | | - let type = this.helper.resolveType(node) |
185 | | - |
186 | | - // `voted = LocalState<uint64>()` is resolved to FunctionPType with returnType LocalState<uint64> |
187 | | - if (type instanceof FunctionPType) type = type.returnType |
188 | | - |
189 | | - if (typeRegistry.isGeneric(type)) { |
190 | | - const info = getGenericTypeInfo(type) |
191 | | - return nodeFactory.captureGenericTypeInfo(ts.visitEachChild(node, this.visit, this.context), JSON.stringify(info)) |
192 | | - } |
| 259 | + return new ExpressionVisitor(this.context, this.helper, node).result() |
193 | 260 | } |
194 | 261 | return ts.visitEachChild(node, this.visit, this.context) |
195 | 262 | } |
196 | 263 | } |
197 | 264 |
|
| 265 | +const isGenericType = (type: PType): boolean => |
| 266 | + instanceOfAny( |
| 267 | + type, |
| 268 | + ARC4StructType, |
| 269 | + ARC4TupleType, |
| 270 | + BoxMapPType, |
| 271 | + BoxPType, |
| 272 | + DynamicArrayType, |
| 273 | + GlobalStateType, |
| 274 | + LocalStateType, |
| 275 | + StaticArrayType, |
| 276 | + UFixedNxMType, |
| 277 | + UintNType, |
| 278 | + ) |
| 279 | + |
198 | 280 | const getGenericTypeInfo = (type: PType): TypeInfo => { |
199 | | - let genericArgs: TypeInfo[] | Record<string, TypeInfo> | undefined = typeRegistry.isGeneric(type) |
200 | | - ? type.getGenericArgs().map(getGenericTypeInfo) |
201 | | - : undefined |
202 | | - |
203 | | - if (!genericArgs || !genericArgs.length) { |
204 | | - if (Object.hasOwn(type, 'items')) { |
205 | | - genericArgs = (type as DeliberateAny).items.map(getGenericTypeInfo) |
206 | | - } else if (Object.hasOwn(type, 'itemType')) { |
207 | | - genericArgs = [getGenericTypeInfo((type as DeliberateAny).itemType)] |
208 | | - } else if (Object.hasOwn(type, 'properties')) { |
209 | | - genericArgs = Object.fromEntries( |
210 | | - Object.entries((type as DeliberateAny).properties).map(([key, value]) => [key, getGenericTypeInfo(value as PType)]), |
211 | | - ) |
212 | | - } |
| 281 | + const genericArgs: TypeInfo[] | Record<string, TypeInfo> = [] |
| 282 | + |
| 283 | + if (instanceOfAny(type, LocalStateType, GlobalStateType, BoxPType)) { |
| 284 | + genericArgs.push(getGenericTypeInfo(type.contentType)) |
| 285 | + } else if (type instanceof BoxMapPType) { |
| 286 | + genericArgs.push(getGenericTypeInfo(type.keyType)) |
| 287 | + genericArgs.push(getGenericTypeInfo(type.contentType)) |
| 288 | + } else if (instanceOfAny(type, StaticArrayType, DynamicArrayType)) { |
| 289 | + genericArgs.push(getGenericTypeInfo(type.elementType)) |
| 290 | + } else if (type instanceof UFixedNxMType) { |
| 291 | + genericArgs.push({ name: type.n.toString() }) |
| 292 | + genericArgs.push({ name: type.m.toString() }) |
| 293 | + } else if (type instanceof UintNType) { |
| 294 | + genericArgs.push({ name: type.n.toString() }) |
| 295 | + } else if (type instanceof ARC4StructType) { |
| 296 | + genericArgs.push( |
| 297 | + ...Object.fromEntries( |
| 298 | + Object.entries(type.fields) |
| 299 | + .map(([key, value]) => [key, getGenericTypeInfo(value)]) |
| 300 | + .filter((x) => !!x), |
| 301 | + ), |
| 302 | + ) |
| 303 | + } else if (type instanceof ARC4TupleType) { |
| 304 | + genericArgs.push(...type.items.map(getGenericTypeInfo)) |
213 | 305 | } |
214 | 306 |
|
215 | | - const result: TypeInfo = { name: type?.name ?? 'unknown' } |
| 307 | + const result: TypeInfo = { name: type?.name ?? type?.toString() ?? 'unknown' } |
216 | 308 | if (genericArgs && (genericArgs.length || Object.keys(genericArgs).length)) { |
217 | 309 | result.genericArgs = genericArgs |
218 | 310 | } |
|
0 commit comments