Skip to content

Commit 7e17ce1

Browse files
authored
support ADT style enums (#2)
1 parent 6e6e200 commit 7e17ce1

File tree

4 files changed

+112
-4
lines changed

4 files changed

+112
-4
lines changed

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ libraryDependencies ++= Seq(
2121
"com.github.swagger-akka-http" %% "swagger-scala-module" % "2.8.2",
2222
"com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.14.1",
2323
"org.scalatest" %% "scalatest" % "3.2.14" % Test,
24-
"org.slf4j" % "slf4j-simple" % "2.0.5" % Test
24+
"org.slf4j" % "slf4j-simple" % "2.0.6" % Test
2525
)
2626

2727
homepage := Some(new URL("https://github.com/swagger-akka-http/swagger-scala3-enum-module"))

src/main/scala/com/github/swagger/scala3enum/converter/SwaggerScala3EnumModelConverter.scala

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,25 @@ import io.swagger.v3.oas.annotations.media.{ArraySchema, Schema as SchemaAnnotat
1010
import io.swagger.v3.oas.models.media.Schema
1111

1212
import java.lang.annotation.Annotation
13+
import java.lang.reflect.InvocationTargetException
1314
import java.util.Iterator
1415
import scala.reflect.Enum
16+
import scala.util.Try
1517

1618
class SwaggerScala3EnumModelConverter extends ModelResolver(Json.mapper()) {
1719
private val enumEntryClass = classOf[Enum]
20+
private val IntClass = classOf[Int]
1821

1922
override def resolve(annotatedType: AnnotatedType, context: ModelConverterContext, chain: Iterator[ModelConverter]): Schema[_] = {
2023
val javaType = _mapper.constructType(annotatedType.getType)
2124
val cls = javaType.getRawClass
2225
if (isEnum(cls)) {
2326
val sp: Schema[String] = PrimitiveType.STRING.createProperty().asInstanceOf[Schema[String]]
2427
setRequired(annotatedType)
25-
getValues(cls).foreach { v =>
26-
sp.addEnumItemObject(v)
28+
tryValues(cls).toOption.orElse(matchBasedOnOrdinals(cls)).map { values =>
29+
values.foreach { v =>
30+
sp.addEnumItemObject(v)
31+
}
2732
}
2833
nullSafeList(annotatedType.getCtxAnnotations).foreach {
2934
case p: Parameter => {
@@ -63,12 +68,48 @@ class SwaggerScala3EnumModelConverter extends ModelResolver(Json.mapper()) {
6368

6469
private def isEnum(cls: Class[_]): Boolean = enumEntryClass.isAssignableFrom(cls)
6570

66-
private def getValues(cls: Class[_]): Seq[String] = {
71+
private def tryValues(cls: Class[_]): Try[Seq[String]] = Try {
6772
val enumCompanion = Class.forName(cls.getName + "$").getField("MODULE$").get(null)
6873
val enumArray = enumCompanion.getClass.getDeclaredMethod("values").invoke(enumCompanion).asInstanceOf[Array[Enum]]
6974
enumArray.sortBy(_.ordinal).map(_.toString).toSeq
7075
}
7176

77+
private def matchBasedOnOrdinals(clz: Class[_]): Option[Seq[String]] = {
78+
val className = clz.getName
79+
val companionObjectClassOption = if (className.endsWith("$")) {
80+
Some(clz)
81+
} else {
82+
Try(Class.forName(className + "$")).toOption
83+
}
84+
companionObjectClassOption.flatMap { companionObjectClass =>
85+
Try(companionObjectClass.getField("MODULE$")).toOption.flatMap { moduleField =>
86+
val instance = moduleField.get(None.orNull)
87+
Try(clz.getMethod("fromOrdinal", IntClass)).toOption.map { method =>
88+
var i = 0
89+
var matched: Seq[String] = Seq.empty[String]
90+
var complete = false
91+
while (!complete) {
92+
try {
93+
val enumValue = method.invoke(instance, i)
94+
matched = matched :+ enumValue.toString
95+
} catch {
96+
case _: NoSuchElementException => complete = true
97+
case itex: InvocationTargetException => {
98+
Option(itex.getCause) match {
99+
case Some(e) if e.isInstanceOf[NoSuchElementException] => complete = true
100+
case Some(e) => throw e
101+
case _ => throw itex
102+
}
103+
}
104+
}
105+
i += 1
106+
}
107+
matched
108+
}
109+
}
110+
}
111+
}
112+
72113
private def setRequired(annotatedType: AnnotatedType): Unit = annotatedType match {
73114
case _: AnnotatedTypeForOption => // not required
74115
case _ => {
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package com.github.swagger.scala3enum.converter.adt
2+
3+
enum Color(val rgb: Int):
4+
case Red extends Color(0xFF0000)
5+
case Green extends Color(0x00FF00)
6+
case Blue extends Color(0x0000FF)
7+
case Mix(mix: Int) extends Color(mix)
8+
9+
case class ColorSet(set: Set[Color])
10+
11+
case class Car(make: String, color: Color)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package com.github.swagger.scala3enum.converter.adt
2+
3+
import io.swagger.v3.core.converter.ModelConverters
4+
import io.swagger.v3.oas.models.media.{ArraySchema, Schema, StringSchema}
5+
import org.scalatest.OptionValues
6+
import org.scalatest.matchers.should.Matchers
7+
import org.scalatest.wordspec.AnyWordSpec
8+
9+
import scala.jdk.CollectionConverters.*
10+
import scala.reflect.Enum
11+
12+
class SwaggerAdtConverterSpec extends AnyWordSpec with Matchers with OptionValues {
13+
"SwaggerScala3EnumModelConverter" should {
14+
"get model for Car" in {
15+
val converter = ModelConverters.getInstance()
16+
val schemas = converter.readAll(classOf[Car]).asScala.toMap
17+
val model = findModel(schemas, "Car")
18+
model should be(defined)
19+
model.get.getProperties should not be (null)
20+
val field = model.value.getProperties.get("color")
21+
field shouldBe a[StringSchema]
22+
nullSafeList(field.asInstanceOf[StringSchema].getEnum) shouldEqual Seq("Red", "Green", "Blue")
23+
nullSafeList(field.getRequired) shouldBe empty
24+
nullSafeList(model.value.getRequired) shouldEqual Seq("color", "make")
25+
}
26+
"get model for ColorSet" in {
27+
val converter = ModelConverters.getInstance()
28+
val schemas = converter.readAll(classOf[ColorSet]).asScala.toMap
29+
val model = findModel(schemas, "ColorSet")
30+
model should be (defined)
31+
model.get.getProperties should not be (null)
32+
val field = model.value.getProperties.get("set")
33+
field shouldBe an [ArraySchema]
34+
val arraySchema = field.asInstanceOf[ArraySchema]
35+
nullSafeList(arraySchema.getItems.getEnum) shouldEqual Seq("Red", "Green", "Blue")
36+
nullSafeList(arraySchema.getRequired) shouldBe empty
37+
nullSafeList(model.value.getRequired) shouldEqual Seq("set")
38+
}
39+
}
40+
41+
private def findModel(schemas: Map[String, Schema[_]], name: String): Option[Schema[_]] = {
42+
schemas.get(name) match {
43+
case Some(m) => Some(m)
44+
case None =>
45+
schemas.keys.find { case k => k.startsWith(name) } match {
46+
case Some(key) => schemas.get(key)
47+
case None => schemas.values.headOption
48+
}
49+
}
50+
}
51+
52+
private def nullSafeList[T](list: java.util.List[T]): List[T] = Option(list) match {
53+
case None => List[T]()
54+
case Some(l) => l.asScala.toList
55+
}
56+
}

0 commit comments

Comments
 (0)