@@ -5,6 +5,7 @@ import org.utbot.intellij.plugin.ui.utils.PsiElementHandler
55import com.intellij.openapi.actionSystem.AnAction
66import com.intellij.openapi.actionSystem.AnActionEvent
77import com.intellij.openapi.actionSystem.CommonDataKeys
8+ import com.intellij.openapi.actionSystem.PlatformDataKeys
89import com.intellij.openapi.editor.Editor
910import com.intellij.openapi.module.ModuleUtil
1011import com.intellij.openapi.project.Project
@@ -19,19 +20,21 @@ import org.jetbrains.kotlin.idea.core.util.toPsiDirectory
1920import org.jetbrains.kotlin.idea.core.util.toPsiFile
2021import org.utbot.intellij.plugin.util.extractFirstLevelMembers
2122import java.util.*
23+ import org.jetbrains.kotlin.j2k.getContainingClass
24+ import org.jetbrains.kotlin.utils.addIfNotNull
2225
2326class GenerateTestsAction : AnAction () {
2427 override fun actionPerformed (e : AnActionEvent ) {
2528 val project = e.project ? : return
26- val (srcClasses, focusedMethod , extractMembersFromSrcClasses) = getPsiTargets(e) ? : return
27- UtTestsDialogProcessor .createDialogAndGenerateTests(project, srcClasses, extractMembersFromSrcClasses, focusedMethod )
29+ val (srcClasses, focusedMethods , extractMembersFromSrcClasses) = getPsiTargets(e) ? : return
30+ UtTestsDialogProcessor .createDialogAndGenerateTests(project, srcClasses, extractMembersFromSrcClasses, focusedMethods )
2831 }
2932
3033 override fun update (e : AnActionEvent ) {
3134 e.presentation.isEnabled = getPsiTargets(e) != null
3235 }
3336
34- private fun getPsiTargets (e : AnActionEvent ): Triple <Set <PsiClass >, MemberInfo? , Boolean>? {
37+ private fun getPsiTargets (e : AnActionEvent ): Triple <Set <PsiClass >, Set< MemberInfo> , Boolean>? {
3538 val project = e.project ? : return null
3639 val editor = e.getData(CommonDataKeys .EDITOR )
3740 if (editor != null ) {
@@ -56,19 +59,19 @@ class GenerateTestsAction : AnAction() {
5659 return null
5760 }
5861
59- return Triple (setOf (srcClass), focusedMethod, true )
62+ return Triple (setOf (srcClass), if ( focusedMethod != null ) setOf (focusedMethod) else emptySet() , true )
6063 }
6164 } else {
6265 // The action is being called from 'Project' tool window
6366 val srcClasses = mutableSetOf<PsiClass >()
64- var selectedMethod : MemberInfo ? = null
67+ val selectedMethods = mutableSetOf< MemberInfo >()
6568 var extractMembersFromSrcClasses = false
66- val element = e.getData(CommonDataKeys .PSI_ELEMENT ) ? : return null
69+ val element = e.getData(CommonDataKeys .PSI_ELEMENT )
6770 if (element is PsiFileSystemItem ) {
6871 e.getData(CommonDataKeys .VIRTUAL_FILE_ARRAY )?.let {
6972 srcClasses + = getAllClasses(project, it)
7073 }
71- } else {
74+ } else if (element is PsiElement ) {
7275 val file = element.containingFile ? : return null
7376 val psiElementHandler = PsiElementHandler .makePsiElementHandler(file)
7477
@@ -81,11 +84,29 @@ class GenerateTestsAction : AnAction() {
8184 }
8285
8386 if (element is PsiMethod ) {
84- selectedMethod = MemberInfo (element)
87+ selectedMethods.add(MemberInfo (element))
88+ }
89+ }
90+ } else {
91+ val someSelection = e.getData(PlatformDataKeys .SELECTED_ITEMS )? : return null
92+ someSelection.forEach {
93+ when (it) {
94+ is PsiFileSystemItem -> srcClasses + = getAllClasses(project, arrayOf(it.virtualFile))
95+ is PsiClass -> srcClasses.add(it)
96+ is PsiElement -> {
97+ srcClasses.addIfNotNull(it.getContainingClass())
98+ if (it is PsiMethod ) {
99+ selectedMethods.add(MemberInfo (it))
100+ extractMembersFromSrcClasses = true
101+ }
102+ }
85103 }
86104 }
87105 }
88106 srcClasses.removeIf { it.isInterface }
107+ if (srcClasses.size > 1 ) {
108+ extractMembersFromSrcClasses = false
109+ }
89110 var commonSourceRoot = null as VirtualFile ?
90111 for (srcClass in srcClasses) {
91112 if (commonSourceRoot == null ) {
@@ -100,7 +121,7 @@ class GenerateTestsAction : AnAction() {
100121 .filter { folder -> ! folder.rootType.isForTests && folder.file == commonSourceRoot}
101122 .findAny().isPresent ) return null
102123
103- return Triple (srcClasses.toSet(), selectedMethod , extractMembersFromSrcClasses)
124+ return Triple (srcClasses.toSet(), selectedMethods.toSet() , extractMembersFromSrcClasses)
104125 }
105126 return null
106127 }
0 commit comments