1919import './flags_webgl' ;
2020
2121import * as tf from '@tensorflow/tfjs-core' ;
22- import { backend_util , BackendValues , buffer , DataId , DataStorage , DataToGPUWebGLOption , DataType , DataValues , engine , env , GPUData , kernel_impls , KernelBackend , MemoryInfo , NumericDataType , Rank , RecursiveArray , scalar , ShapeMap , Tensor , Tensor2D , TensorBuffer , TensorInfo , tidy , TimingInfo , TypedArray , util } from '@tensorflow/tfjs-core' ;
23-
22+ import { backend_util , BackendValues , buffer , DataId , DataStorage , DataToGPUWebGLOption , DataType , DataValues , engine , env , GPUData , kernel_impls , KernelBackend , MemoryInfo , nextFrame , NumericDataType , Rank , RecursiveArray , scalar , ShapeMap , Tensor , Tensor2D , TensorBuffer , TensorInfo , tidy , TimingInfo , TypedArray , util } from '@tensorflow/tfjs-core' ;
2423import { getWebGLContext } from './canvas_util' ;
2524import { DecodeMatrixProgram } from './decode_matrix_gpu' ;
2625import { DecodeMatrixPackedProgram } from './decode_matrix_packed_gpu' ;
@@ -30,7 +29,7 @@ import {EncodeMatrixProgram} from './encode_matrix_gpu';
3029import { EncodeMatrixPackedProgram } from './encode_matrix_packed_gpu' ;
3130import { GPGPUContext } from './gpgpu_context' ;
3231import * as gpgpu_math from './gpgpu_math' ;
33- import { GPGPUBinary , GPGPUProgram , TensorData } from './gpgpu_math' ;
32+ import { getUniformLocations , GPGPUBinary , GPGPUProgram , TensorData } from './gpgpu_math' ;
3433import { simpleAbsImplCPU } from './kernel_utils/shared' ;
3534import { PackProgram } from './pack_gpu' ;
3635import { ReshapePackedProgram } from './reshape_packed_gpu' ;
@@ -549,15 +548,16 @@ export class MathBackendWebGL extends KernelBackend {
549548 } ;
550549
551550 return ( async ( ) => {
552- if ( env ( )
553- . getNumber ( 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE' ) > 0 ) {
551+ if ( env ( ) . getNumber ( 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE' ) >
552+ 0 ) {
554553 const kernelMs = await Promise . all ( flattenedActiveTimerQueries ) ;
555554
556555 res [ 'kernelMs' ] = util . sum ( kernelMs ) ;
557556 res [ 'getExtraProfileInfo' ] = ( ) =>
558- kernelMs . map ( ( d , i ) => ( { name : flattenedActiveTimerNames [ i ] , ms : d } ) )
559- . map ( d => `${ d . name } : ${ d . ms } ` )
560- . join ( ', ' ) ;
557+ kernelMs
558+ . map ( ( d , i ) => ( { name : flattenedActiveTimerNames [ i ] , ms : d } ) )
559+ . map ( d => `${ d . name } : ${ d . ms } ` )
560+ . join ( ', ' ) ;
561561 } else {
562562 res [ 'kernelMs' ] = {
563563 error : 'WebGL query timers are not supported in this environment.'
@@ -949,8 +949,10 @@ export class MathBackendWebGL extends KernelBackend {
949949 query = this . startTimer ( ) ;
950950 }
951951
952- gpgpu_math . runProgram (
953- this . gpgpu , binary , inputsData , outputData , customUniformValues ) ;
952+ if ( ! env ( ) . get ( 'ENGINE_COMPILE_ONLY' ) ) {
953+ gpgpu_math . runProgram (
954+ this . gpgpu , binary , inputsData , outputData , customUniformValues ) ;
955+ }
954956
955957 dataToDispose . forEach ( info => this . disposeIntermediateTensorInfo ( info ) ) ;
956958
@@ -1130,16 +1132,21 @@ export class MathBackendWebGL extends KernelBackend {
11301132
11311133 // Have the original texture assume the identity of the encoded output.
11321134 const outputTexData = this . texData . get ( encodedOutputTarget . dataId ) ;
1133- texData . texture = outputTexData . texture ;
11341135 texData . texShape = outputTexData . texShape ;
11351136 texData . isPacked = outputTexData . isPacked ;
11361137 texData . usage = outputTexData . usage ;
11371138
1139+ if ( ! env ( ) . get ( 'ENGINE_COMPILE_ONLY' ) ) {
1140+ texData . texture = outputTexData . texture ;
1141+ // Once uploaded, don't store the values on cpu.
1142+ texData . values = null ;
1143+ this . texData . delete ( encodedOutputTarget . dataId ) ;
1144+ } else {
1145+ this . disposeData ( encodedOutputTarget . dataId ) ;
1146+ }
1147+
11381148 this . disposeIntermediateTensorInfo ( tempDenseInputHandle ) ;
1139- this . texData . delete ( encodedOutputTarget . dataId ) ;
11401149
1141- // Once uploaded, don't store the values on cpu.
1142- texData . values = null ;
11431150 if ( shouldTimeProgram ) {
11441151 this . uploadWaitMs += util . now ( ) - start ;
11451152 }
@@ -1180,6 +1187,87 @@ export class MathBackendWebGL extends KernelBackend {
/**
 * Computes a byte count for a two-entry shape.
 *
 * @param shape Pair of numbers that are multiplied together.
 * @param dtype Data type whose per-element size is looked up through
 *     `util.bytesPerElement`.
 * @returns `shape[0] * shape[1]` times the per-element byte size of `dtype`.
 */
private computeBytes(shape: [number, number], dtype: DataType) {
  const [first, second] = shape;
  const elementCount = first * second;
  return elementCount * util.bytesPerElement(dtype);
}
1190+
/**
 * Runs the synchronous link/compile status check on every binary stored in
 * `binaryCache`. The first binary that fails the check makes this method
 * throw, because `checkCompletion_` throws on failure.
 */
checkCompileCompletion() {
  Object.values(this.binaryCache).forEach(cachedBinary => {
    this.checkCompletion_(cachedBinary);
  });
}
1196+
/**
 * Checks every binary stored in `binaryCache` and resolves once all of them
 * have passed.
 *
 * When `gpgpu.parallelCompilationExtension` is set, each binary is polled
 * through `checkCompletionAsync_`. Otherwise each binary is checked
 * synchronously with `checkCompletion_`, and any error it throws is turned
 * into a rejection of that binary's promise, so the remaining binaries are
 * still checked.
 *
 * @returns A promise resolving to one `true` per cached binary. It rejects
 *     with the first error raised by a failing check.
 */
async checkCompileCompletionAsync(): Promise<boolean[]> {
  const binaries = Object.values(this.binaryCache);

  let pending: Array<Promise<boolean>>;
  if (this.gpgpu.parallelCompilationExtension) {
    pending = binaries.map(binary => this.checkCompletionAsync_(binary));
  } else {
    // An async arrow runs its body synchronously up to the first await and
    // converts a thrown error into a rejected promise, which is exactly what
    // the check needs; no explicit Promise constructor or rethrow required.
    pending = binaries.map(
        binary => (async () => this.checkCompletion_(binary))());
  }

  return Promise.all(pending);
}
1219+
/**
 * Waits until the program of `binary` reports completion, then validates it.
 *
 * The `COMPLETION_STATUS_KHR` parameter of the parallel compilation
 * extension is queried repeatedly, yielding with `nextFrame()` between
 * queries, until it becomes truthy.
 *
 * @param binary The cached binary whose `webGLProgram` is polled.
 * @returns The result of `checkCompletion_(binary)`; the promise rejects if
 *     that check throws.
 */
private async checkCompletionAsync_(binary: GPGPUBinary): Promise<boolean> {
  while (!this.gpgpu.gl.getProgramParameter(
      binary.webGLProgram,
      this.gpgpu.parallelCompilationExtension.COMPLETION_STATUS_KHR)) {
    // Not finished yet: yield and query again afterwards.
    await nextFrame();
  }
  return this.checkCompletion_(binary);
}
1230+
/**
 * Validates that the program of `binary` linked successfully.
 *
 * If `LINK_STATUS` is `false`, the program info log is printed. If the
 * fragment shader's `COMPILE_STATUS` is also `false`, the shader source and
 * its info log are logged as well.
 *
 * @param binary The cached binary whose program and fragment shader are
 *     inspected.
 * @returns `true` when the link status is anything other than `false`.
 * @throws Error 'Failed to compile fragment shader.' when the fragment
 *     shader did not compile, or 'Failed to link vertex and fragment
 *     shaders.' when only linking failed.
 */
private checkCompletion_(binary: GPGPUBinary): boolean {
  const gl = this.gpgpu.gl;

  const linkStatus = gl.getProgramParameter(binary.webGLProgram, gl.LINK_STATUS);
  if (linkStatus !== false) {
    return true;
  }

  console.log(gl.getProgramInfoLog(binary.webGLProgram));

  const compileStatus =
      gl.getShaderParameter(binary.fragmentShader, gl.COMPILE_STATUS);
  if (compileStatus !== false) {
    // The shader itself compiled, so the failure happened at link time.
    throw new Error('Failed to link vertex and fragment shaders.');
  }

  webgl_util.logShaderSourceAndInfoLog(
      binary.source, gl.getShaderInfoLog(binary.fragmentShader));
  throw new Error('Failed to compile fragment shader.');
}
1246+
/**
 * Refreshes the uniform location fields of every binary in `binaryCache`.
 *
 * For each cached binary the module-level `getUniformLocations` helper is
 * called with the GPGPU context, the binary's program description and its
 * WebGL program, and the returned locations are stored back on the binary.
 * NOTE(review): presumably intended to run after compilation has been
 * confirmed complete; confirm against callers.
 */
getUniformLocations() {
  for (const binary of Object.values(this.binaryCache)) {
    const locations =
        getUniformLocations(this.gpgpu, binary.program, binary.webGLProgram);

    binary.uniformLocations = locations.uniformLocations;
    binary.customUniformLocations = locations.customUniformLocations;
    binary.infLoc = locations.infLoc;
    binary.nanLoc = locations.nanLoc;
    binary.inShapesLocations = locations.inShapesLocations;
    binary.inTexShapesLocations = locations.inTexShapesLocations;
    binary.outShapeLocation = locations.outShapeLocation;
    binary.outShapeStridesLocation = locations.outShapeStridesLocation;
    binary.outTexShapeLocation = locations.outTexShapeLocation;
  }
}
11831271}
11841272
11851273function float32ToTypedArray < D extends NumericDataType > (
0 commit comments