1+ using OpenCVForUnity . CoreModule ;
2+ using OpenCVForUnity . DnnModule ;
3+ using OpenCVForUnity . ImgprocModule ;
4+ using System ;
5+ using System . Collections . Generic ;
6+ using System . Linq ;
7+ using System . Runtime . InteropServices ;
8+ using System . Text ;
9+ using UnityEngine ;
10+ using OpenCVRect = OpenCVForUnity . CoreModule . Rect ;
11+
12+ namespace YOLOv8WithOpenCVForUnity
13+ {
14+
15+ public class YOLOv8ClassPredictor
16+ {
17+ Size input_size ;
18+ int backend ;
19+ int target ;
20+
21+ Net classification_net ;
22+ List < string > classNames ;
23+
24+ List < Scalar > palette ;
25+
26+ Mat input_sizeMat ;
27+
28+ Mat getDataMat ;
29+
30+ public YOLOv8ClassPredictor ( string modelFilepath , string classesFilepath , Size inputSize , int backend = Dnn . DNN_BACKEND_OPENCV , int target = Dnn . DNN_TARGET_CPU )
31+ {
32+ // initialize
33+ if ( ! string . IsNullOrEmpty ( modelFilepath ) )
34+ {
35+ classification_net = Dnn . readNet ( modelFilepath ) ;
36+ }
37+
38+ if ( ! string . IsNullOrEmpty ( classesFilepath ) )
39+ {
40+ classNames = readClassNames ( classesFilepath ) ;
41+ }
42+
43+ input_size = new Size ( inputSize . width > 0 ? inputSize . width : 224 , inputSize . height > 0 ? inputSize . height : 224 ) ;
44+ this . backend = backend ;
45+ this . target = target ;
46+
47+ classification_net . setPreferableBackend ( this . backend ) ;
48+ classification_net . setPreferableTarget ( this . target ) ;
49+
50+ palette = new List < Scalar > ( ) ;
51+ palette . Add ( new Scalar ( 255 , 56 , 56 , 255 ) ) ;
52+ palette . Add ( new Scalar ( 255 , 157 , 151 , 255 ) ) ;
53+ palette . Add ( new Scalar ( 255 , 112 , 31 , 255 ) ) ;
54+ palette . Add ( new Scalar ( 255 , 178 , 29 , 255 ) ) ;
55+ palette . Add ( new Scalar ( 207 , 210 , 49 , 255 ) ) ;
56+ palette . Add ( new Scalar ( 72 , 249 , 10 , 255 ) ) ;
57+ palette . Add ( new Scalar ( 146 , 204 , 23 , 255 ) ) ;
58+ palette . Add ( new Scalar ( 61 , 219 , 134 , 255 ) ) ;
59+ palette . Add ( new Scalar ( 26 , 147 , 52 , 255 ) ) ;
60+ palette . Add ( new Scalar ( 0 , 212 , 187 , 255 ) ) ;
61+ palette . Add ( new Scalar ( 44 , 153 , 168 , 255 ) ) ;
62+ palette . Add ( new Scalar ( 0 , 194 , 255 , 255 ) ) ;
63+ palette . Add ( new Scalar ( 52 , 69 , 147 , 255 ) ) ;
64+ palette . Add ( new Scalar ( 100 , 115 , 255 , 255 ) ) ;
65+ palette . Add ( new Scalar ( 0 , 24 , 236 , 255 ) ) ;
66+ palette . Add ( new Scalar ( 132 , 56 , 255 , 255 ) ) ;
67+ palette . Add ( new Scalar ( 82 , 0 , 133 , 255 ) ) ;
68+ palette . Add ( new Scalar ( 203 , 56 , 255 , 255 ) ) ;
69+ palette . Add ( new Scalar ( 255 , 149 , 200 , 255 ) ) ;
70+ palette . Add ( new Scalar ( 255 , 55 , 199 , 255 ) ) ;
71+ }
72+
73+ protected virtual Mat preprocess ( Mat image )
74+ {
75+
76+ // Create a 4D blob from a frame.
77+
78+ int c = image . channels ( ) ;
79+ int h = ( int ) input_size . height ;
80+ int w = ( int ) input_size . width ;
81+
82+ if ( input_sizeMat == null )
83+ input_sizeMat = new Mat ( h , w , CvType . CV_8UC3 ) ; // [h, w]
84+
85+ int imh = image . height ( ) ;
86+ int imw = image . width ( ) ;
87+ int m = Mathf . Min ( imh , imw ) ;
88+ int top = ( int ) ( ( imh - m ) / 2f ) ;
89+ int left = ( int ) ( ( imw - m ) / 2f ) ;
90+ Mat image_crop = new Mat ( image , new OpenCVRect ( 0 , 0 , image . width ( ) , image . height ( ) ) . intersect ( new OpenCVRect ( left , top , m , m ) ) ) ;
91+ Imgproc . resize ( image_crop , input_sizeMat , new Size ( w , h ) ) ;
92+
93+ Mat blob = Dnn . blobFromImage ( input_sizeMat , 1.0 / 255.0 , input_size , Scalar . all ( 0 ) , true , false , CvType . CV_32F ) ; // HWC to NCHW, BGR to RGB
94+
95+ return blob ; // [1, 3, h, w]
96+
97+ }
98+
99+ public virtual Mat infer ( Mat image )
100+ {
101+ // cheack
102+ if ( image . channels ( ) != 3 )
103+ {
104+ Debug . Log ( "The input image must be in BGR format." ) ;
105+ return new Mat ( ) ;
106+ }
107+
108+ // Preprocess
109+ Mat input_blob = preprocess ( image ) ;
110+
111+ // Forward
112+ classification_net . setInput ( input_blob ) ;
113+
114+ List < Mat > output_blob = new List < Mat > ( ) ;
115+ classification_net . forward ( output_blob , classification_net . getUnconnectedOutLayersNames ( ) ) ;
116+
117+ // Postprocess
118+ Mat results = postprocess ( output_blob , image . size ( ) ) ;
119+
120+ input_blob . Dispose ( ) ;
121+ for ( int i = 0 ; i < output_blob . Count ; i ++ )
122+ {
123+ output_blob [ i ] . Dispose ( ) ;
124+ }
125+
126+ return results ;
127+ }
128+
129+ protected virtual Mat postprocess ( List < Mat > output_blob , Size original_shape )
130+ {
131+ Mat output_blob_0 = output_blob [ 0 ] ;
132+
133+ Mat results = output_blob_0 . clone ( ) ;
134+
135+ return results ; // [1, num_classes]
136+ }
137+
138+ protected virtual Mat softmax ( Mat src )
139+ {
140+ Mat dst = src . clone ( ) ;
141+
142+ Core . MinMaxLocResult result = Core . minMaxLoc ( src ) ;
143+ Scalar max = new Scalar ( result . maxVal ) ;
144+ Core . subtract ( src , max , dst ) ;
145+ Core . exp ( dst , dst ) ;
146+ Scalar sum = Core . sumElems ( dst ) ;
147+ Core . divide ( dst , sum , dst ) ;
148+
149+ return dst ;
150+ }
151+
152+ public virtual void visualize ( Mat image , Mat results , bool print_results = false , bool isRGB = false )
153+ {
154+ if ( image . IsDisposed )
155+ return ;
156+
157+ if ( results . empty ( ) || results . cols ( ) < classNames . Count )
158+ return ;
159+
160+ StringBuilder sb = null ;
161+
162+ if ( print_results )
163+ sb = new StringBuilder ( ) ;
164+
165+ ClassificationData bmData = getBestMatchData ( results ) ;
166+ int classId = ( int ) bmData . cls ;
167+ string label = getClassLabel ( bmData . cls ) + ", " + String . Format ( "{0:0.00}" , bmData . conf ) ;
168+
169+ Scalar c = palette [ classId % palette . Count ] ;
170+ Scalar color = isRGB ? c : new Scalar ( c . val [ 2 ] , c . val [ 1 ] , c . val [ 0 ] , c . val [ 3 ] ) ;
171+
172+ int [ ] baseLine = new int [ 1 ] ;
173+ Size labelSize = Imgproc . getTextSize ( label , Imgproc . FONT_HERSHEY_SIMPLEX , 1.0 , 1 , baseLine ) ;
174+
175+ float top = 20f + ( float ) labelSize . height ;
176+ float left = ( float ) ( image . width ( ) / 2 - labelSize . width / 2f ) ;
177+
178+ top = Mathf . Max ( ( float ) top , ( float ) labelSize . height ) ;
179+ Imgproc . rectangle ( image , new Point ( left , top - labelSize . height ) ,
180+ new Point ( left + labelSize . width , top + baseLine [ 0 ] ) , color , Core . FILLED ) ;
181+ Imgproc . putText ( image , label , new Point ( left , top ) , Imgproc . FONT_HERSHEY_SIMPLEX , 1.0 , Scalar . all ( 255 ) , 1 , Imgproc . LINE_AA ) ;
182+
183+ // Print results
184+ if ( print_results )
185+ {
186+ sb . AppendLine ( String . Format ( "Best match: " + getClassLabel ( bmData . cls ) + ", " + bmData ) ) ;
187+ }
188+
189+ if ( print_results )
190+ Debug . Log ( sb ) ;
191+ }
192+
193+ public virtual void dispose ( )
194+ {
195+ if ( classification_net != null )
196+ classification_net . Dispose ( ) ;
197+
198+ if ( input_sizeMat != null )
199+ input_sizeMat . Dispose ( ) ;
200+
201+ input_sizeMat = null ;
202+
203+ if ( getDataMat != null )
204+ getDataMat . Dispose ( ) ;
205+
206+ getDataMat = null ;
207+ }
208+
209+ protected virtual List < string > readClassNames ( string filename )
210+ {
211+ List < string > classNames = new List < string > ( ) ;
212+
213+ System . IO . StreamReader cReader = null ;
214+ try
215+ {
216+ cReader = new System . IO . StreamReader ( filename , System . Text . Encoding . Default ) ;
217+
218+ while ( cReader . Peek ( ) >= 0 )
219+ {
220+ string name = cReader . ReadLine ( ) ;
221+ classNames . Add ( name ) ;
222+ }
223+ }
224+ catch ( System . Exception ex )
225+ {
226+ Debug . LogError ( ex . Message ) ;
227+ return null ;
228+ }
229+ finally
230+ {
231+ if ( cReader != null )
232+ cReader . Close ( ) ;
233+ }
234+
235+ return classNames ;
236+ }
237+
238+ [ StructLayout ( LayoutKind . Sequential ) ]
239+ public readonly struct ClassificationData
240+ {
241+ public readonly float cls ;
242+ public readonly float conf ;
243+
244+ // sizeof(ClassificationData)
245+ public const int Size = 2 * sizeof ( float ) ;
246+
247+ public ClassificationData ( int cls , float conf )
248+ {
249+ this . cls = cls ;
250+ this . conf = conf ;
251+ }
252+
253+ public override string ToString ( )
254+ {
255+ return "cls:" + cls + " conf:" + conf ;
256+ }
257+ } ;
258+
259+ public virtual ClassificationData [ ] getData ( Mat results )
260+ {
261+ if ( results . empty ( ) )
262+ return new ClassificationData [ 0 ] ;
263+
264+ int num = results . cols ( ) ;
265+
266+ if ( getDataMat == null )
267+ {
268+ getDataMat = new Mat ( num , 2 , CvType . CV_32FC1 ) ;
269+ float [ ] arange = Enumerable . Range ( 0 , num ) . Select ( i => ( float ) i ) . ToArray ( ) ;
270+ getDataMat . col ( 0 ) . put ( 0 , 0 , arange ) ;
271+ }
272+
273+ Mat results_numx1 = results . reshape ( 1 , num ) ;
274+ results_numx1 . copyTo ( getDataMat . col ( 1 ) ) ;
275+
276+ var dst = new ClassificationData [ num ] ;
277+ OpenCVForUnity . UtilsModule . MatUtils . copyFromMat ( getDataMat , dst ) ;
278+
279+ return dst ;
280+ }
281+
282+ public virtual ClassificationData [ ] getSortedData ( Mat results , int topK = 5 )
283+ {
284+ if ( results . empty ( ) )
285+ return new ClassificationData [ 0 ] ;
286+
287+ int num = results . cols ( ) ;
288+
289+ if ( topK < 1 || topK > num ) topK = num ;
290+ var sortedData = getData ( results ) . OrderByDescending ( x => x . conf ) . Take ( topK ) . ToArray ( ) ;
291+
292+ return sortedData ;
293+ }
294+
295+ public virtual ClassificationData getBestMatchData ( Mat results )
296+ {
297+ if ( results . empty ( ) )
298+ return new ClassificationData ( ) ;
299+
300+ Core . MinMaxLocResult minmax = Core . minMaxLoc ( results ) ;
301+
302+ return new ClassificationData ( ( int ) minmax . maxLoc . x , ( float ) minmax . maxVal ) ;
303+ }
304+
305+ public virtual string getClassLabel ( float id )
306+ {
307+ int classId = ( int ) id ;
308+ string className = string . Empty ;
309+ if ( classNames != null && classNames . Count != 0 )
310+ {
311+ if ( classId >= 0 && classId < ( int ) classNames . Count )
312+ {
313+ className = classNames [ classId ] ;
314+ }
315+ }
316+ if ( string . IsNullOrEmpty ( className ) )
317+ className = classId . ToString ( ) ;
318+
319+ return className ;
320+ }
321+ }
322+ }
0 commit comments