@@ -2834,10 +2834,15 @@ class SYCLKernelNameTypeVisitor
2834
2834
Sema &S;
2835
2835
SourceLocation KernelInvocationFuncLoc;
2836
2836
using InnerTypeVisitor = TypeVisitor<SYCLKernelNameTypeVisitor>;
2837
- using InnerTAVisitor =
2837
+ using InnerTemplArgVisitor =
2838
2838
ConstTemplateArgumentVisitor<SYCLKernelNameTypeVisitor>;
2839
2839
bool IsInvalid = false ;
2840
2840
2841
+ void VisitTemplateArgs (ArrayRef<TemplateArgument> Args) {
2842
+ for (auto &A : Args)
2843
+ Visit (A);
2844
+ }
2845
+
2841
2846
public:
2842
2847
SYCLKernelNameTypeVisitor (Sema &S, SourceLocation KernelInvocationFuncLoc)
2843
2848
: S(S), KernelInvocationFuncLoc(KernelInvocationFuncLoc) {}
@@ -2848,15 +2853,19 @@ class SYCLKernelNameTypeVisitor
2848
2853
if (T.isNull ())
2849
2854
return ;
2850
2855
const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
2851
- if (!RD)
2856
+ if (!RD) {
2857
+ if (T->isNullPtrType ()) {
2858
+ S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
2859
+ << /* kernel name cannot be a type in the std namespace */ 3 ;
2860
+ IsInvalid = true ;
2861
+ }
2852
2862
return ;
2863
+ }
2853
2864
// If KernelNameType has template args visit each template arg via
2854
2865
// ConstTemplateArgumentVisitor
2855
2866
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
2856
- const TemplateArgumentList &Args = TSD->getTemplateArgs ();
2857
- for (unsigned I = 0 ; I < Args.size (); I++) {
2858
- Visit (Args[I]);
2859
- }
2867
+ ArrayRef<TemplateArgument> Args = TSD->getTemplateArgs ().asArray ();
2868
+ VisitTemplateArgs (Args);
2860
2869
} else {
2861
2870
InnerTypeVisitor::Visit (T.getTypePtr ());
2862
2871
}
@@ -2865,7 +2874,7 @@ class SYCLKernelNameTypeVisitor
2865
2874
void Visit (const TemplateArgument &TA) {
2866
2875
if (TA.isNull ())
2867
2876
return ;
2868
- InnerTAVisitor ::Visit (TA);
2877
+ InnerTemplArgVisitor ::Visit (TA);
2869
2878
}
2870
2879
2871
2880
void VisitEnumType (const EnumType *T) {
@@ -2886,22 +2895,31 @@ class SYCLKernelNameTypeVisitor
2886
2895
void VisitTagDecl (const TagDecl *Tag) {
2887
2896
bool UnnamedLambdaEnabled =
2888
2897
S.getASTContext ().getLangOpts ().SYCLUnnamedLambda ;
2889
- if (! Tag->getDeclContext ()-> isTranslationUnit () &&
2890
- !isa<NamespaceDecl>(Tag-> getDeclContext ()) && !UnnamedLambdaEnabled) {
2891
- const bool KernelNameIsMissing = Tag-> getName (). empty ( );
2892
- if (KernelNameIsMissing ) {
2898
+ const DeclContext *DeclCtx = Tag->getDeclContext ();
2899
+ if (DeclCtx && !UnnamedLambdaEnabled) {
2900
+ auto *NameSpace = dyn_cast_or_null<NamespaceDecl>(DeclCtx );
2901
+ if (NameSpace && NameSpace-> isStdNamespace () ) {
2893
2902
S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
2894
- << /* kernel name is missing */ 0 ;
2903
+ << /* kernel name cannot be a type in the std namespace */ 3 ;
2895
2904
IsInvalid = true ;
2896
- } else {
2905
+ return ;
2906
+ }
2907
+ if (!DeclCtx->isTranslationUnit () && !isa<NamespaceDecl>(DeclCtx)) {
2908
+ const bool KernelNameIsMissing = Tag->getName ().empty ();
2909
+ if (KernelNameIsMissing) {
2910
+ S.Diag (KernelInvocationFuncLoc,
2911
+ diag::err_sycl_kernel_incorrectly_named)
2912
+ << /* kernel name is missing */ 0 ;
2913
+ IsInvalid = true ;
2914
+ return ;
2915
+ }
2897
2916
if (Tag->isCompleteDefinition ()) {
2898
2917
S.Diag (KernelInvocationFuncLoc,
2899
2918
diag::err_sycl_kernel_incorrectly_named)
2900
2919
<< /* kernel name is not globally-visible */ 1 ;
2901
2920
IsInvalid = true ;
2902
2921
} else
2903
2922
S.Diag (KernelInvocationFuncLoc, diag::warn_sycl_implicit_decl);
2904
-
2905
2923
S.Diag (Tag->getSourceRange ().getBegin (), diag::note_previous_decl)
2906
2924
<< Tag->getName ();
2907
2925
}
@@ -2932,6 +2950,10 @@ class SYCLKernelNameTypeVisitor
2932
2950
VisitEnumType (ET);
2933
2951
}
2934
2952
}
2953
+
2954
+ void VisitPackTemplateArgument (const TemplateArgument &TA) {
2955
+ VisitTemplateArgs (TA.getPackAsArray ());
2956
+ }
2935
2957
};
2936
2958
2937
2959
void Sema::CheckSYCLKernelCall (FunctionDecl *KernelFunc, SourceRange CallLoc,
@@ -3337,12 +3359,6 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
3337
3359
break ;
3338
3360
}
3339
3361
3340
- if (NS->isStdNamespace ()) {
3341
- Diag.Report (KernelLocation, diag::err_sycl_kernel_incorrectly_named)
3342
- << /* name cannot be a type in the std namespace */ 3 ;
3343
- return ;
3344
- }
3345
-
3346
3362
++NamespaceCnt;
3347
3363
const StringRef NSInlinePrefix = NS->isInline () ? " inline " : " " ;
3348
3364
NSStr.insert (
@@ -3426,9 +3442,6 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
3426
3442
const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
3427
3443
3428
3444
if (!RD) {
3429
- if (T->isNullPtrType ())
3430
- Diag.Report (KernelLocation, diag::err_sycl_kernel_incorrectly_named)
3431
- << /* name cannot be a type in the std namespace */ 3 ;
3432
3445
3433
3446
return ;
3434
3447
}
0 commit comments