@@ -421,3 +421,128 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
421
421
}
422
422
}
423
423
}
424
+
425
+ bool swift::operator ==(const TangentPropertyInfo::Error &lhs,
426
+ const TangentPropertyInfo::Error &rhs) {
427
+ if (lhs.kind != rhs.kind )
428
+ return false ;
429
+ switch (lhs.kind ) {
430
+ case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
431
+ case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
432
+ case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
433
+ case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
434
+ case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
435
+ case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
436
+ return true ;
437
+ case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
438
+ return lhs.getType ()->isEqual (rhs.getType ());
439
+ }
440
+ }
441
+
442
+ void swift::simple_display (llvm::raw_ostream &os, TangentPropertyInfo info) {
443
+ os << " { " ;
444
+ os << " tangent property: "
445
+ << (info.tangentProperty ? info.tangentProperty ->printRef () : " null" );
446
+ if (info.error ) {
447
+ os << " , error: " ;
448
+ switch (info.error ->kind ) {
449
+ case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
450
+ os << " '@noDerivative' original property has no tangent property" ;
451
+ break ;
452
+ case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
453
+ os << " nominal parent does not conform to 'Differentiable'" ;
454
+ break ;
455
+ case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
456
+ os << " original property type does not conform to 'Differentiable'" ;
457
+ break ;
458
+ case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
459
+ os << " 'TangentVector' type is not a struct" ;
460
+ break ;
461
+ case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
462
+ os << " 'TangentVector' struct does not have stored property with the "
463
+ " same name as the original property" ;
464
+ break ;
465
+ case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
466
+ os << " tangent property's type is not equal to the original property's "
467
+ " 'TangentVector' type" ;
468
+ break ;
469
+ case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
470
+ os << " 'TangentVector' property '" << info.tangentProperty ->getName ()
471
+ << " ' is not a stored property" ;
472
+ break ;
473
+ }
474
+ }
475
+ os << " }" ;
476
+ }
477
+
478
+ TangentPropertyInfo
479
+ TangentStoredPropertyRequest::evaluate (Evaluator &evaluator,
480
+ VarDecl *originalField) const {
481
+ assert (originalField->hasStorage () && originalField->isInstanceMember () &&
482
+ " Expected stored property" );
483
+ auto *parentDC = originalField->getDeclContext ();
484
+ assert (parentDC->isTypeContext ());
485
+ auto parentType = parentDC->getDeclaredTypeInContext ();
486
+ auto *moduleDecl = originalField->getModuleContext ();
487
+ auto parentTan = parentType->getAutoDiffTangentSpace (
488
+ LookUpConformanceInModule (moduleDecl));
489
+ // Error if parent nominal type does not conform to `Differentiable`.
490
+ if (!parentTan) {
491
+ return TangentPropertyInfo (
492
+ TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable);
493
+ }
494
+ // Error if original stored property is `@noDerivative`.
495
+ if (originalField->getAttrs ().hasAttribute <NoDerivativeAttr>()) {
496
+ return TangentPropertyInfo (
497
+ TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty);
498
+ }
499
+ // Error if original property's type does not conform to `Differentiable`.
500
+ auto originalFieldTan = originalField->getType ()->getAutoDiffTangentSpace (
501
+ LookUpConformanceInModule (moduleDecl));
502
+ if (!originalFieldTan) {
503
+ return TangentPropertyInfo (
504
+ TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable);
505
+ }
506
+ auto parentTanType = parentTan->getType ();
507
+ auto *parentTanStruct = parentTanType->getStructOrBoundGenericStruct ();
508
+ // Error if parent `TangentVector` is not a struct.
509
+ if (!parentTanStruct) {
510
+ return TangentPropertyInfo (
511
+ TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct);
512
+ }
513
+ // Find the corresponding field in the tangent space.
514
+ VarDecl *tanField = nullptr ;
515
+ // If `TangentVector` is the original struct, then the tangent property is the
516
+ // original property.
517
+ if (parentTanStruct == parentDC->getSelfStructDecl ()) {
518
+ tanField = originalField;
519
+ }
520
+ // Otherwise, look up the field by name.
521
+ else {
522
+ auto tanFieldLookup =
523
+ parentTanStruct->lookupDirect (originalField->getName ());
524
+ llvm::erase_if (tanFieldLookup,
525
+ [](ValueDecl *v) { return !isa<VarDecl>(v); });
526
+ // Error if tangent property could not be found.
527
+ if (tanFieldLookup.empty ()) {
528
+ return TangentPropertyInfo (
529
+ TangentPropertyInfo::Error::Kind::TangentPropertyNotFound);
530
+ }
531
+ tanField = cast<VarDecl>(tanFieldLookup.front ());
532
+ }
533
+ // Error if tangent property's type is not equal to the original property's
534
+ // `TangentVector` type.
535
+ auto originalFieldTanType = originalFieldTan->getType ();
536
+ if (!originalFieldTanType->isEqual (tanField->getType ())) {
537
+ return TangentPropertyInfo (
538
+ TangentPropertyInfo::Error::Kind::TangentPropertyWrongType,
539
+ originalFieldTanType);
540
+ }
541
+ // Error if tangent property is not a stored property.
542
+ if (!tanField->hasStorage ()) {
543
+ return TangentPropertyInfo (
544
+ TangentPropertyInfo::Error::Kind::TangentPropertyNotStored);
545
+ }
546
+ // Otherwise, tangent property is valid.
547
+ return TangentPropertyInfo (tanField);
548
+ }
0 commit comments