001/*******************************************************************************
002 * Copyright (c) 2016 Diamond Light Source Ltd. and others.
003 * All rights reserved. This program and the accompanying materials
004 * are made available under the terms of the Eclipse Public License v1.0
005 * which accompanies this distribution, and is available at
006 * http://www.eclipse.org/legal/epl-v10.html
007 *
008 * Contributors:
009 *     Diamond Light Source Ltd - initial API and implementation
010 *******************************************************************************/
011package org.eclipse.january.dataset;
012
013import java.lang.reflect.Array;
014import java.util.ArrayList;
015import java.util.Arrays;
016import java.util.Collection;
017import java.util.List;
018import java.util.SortedSet;
019import java.util.TreeSet;
020
021public class ShapeUtils {
022
023        private ShapeUtils() {
024        }
025
026        /**
027         * Calculate total number of items in given shape
028         * @param shape dataset shape
029         * @return size
030         */
031        public static long calcLongSize(final int[] shape) {
032                if (shape == null) { // special case of null-shaped
033                        return 0;
034                }
035
036                final int rank = shape.length;
037                if (rank == 0) { // special case of zero-rank shape 
038                        return 1;
039                }
040        
041                double dsize = 1.0;
042                for (int i = 0; i < rank; i++) {
043                        // make sure the indexes isn't zero or negative
044                        if (shape[i] == 0) {
045                                return 0;
046                        } else if (shape[i] < 0) {
047                                throw new IllegalArgumentException(String.format(
048                                                "The %d-th is %d which is not allowed as it is negative", i, shape[i]));
049                        }
050        
051                        dsize *= shape[i];
052                }
053        
054                // check to see if the size is larger than an integer, i.e. we can't allocate it
055                if (dsize > Long.MAX_VALUE) {
056                        throw new IllegalArgumentException("Size of the dataset is too large to allocate");
057                }
058                return (long) dsize;
059        }
060
061        /**
062         * Calculate total number of items in given shape
063         * @param shape dataset shape
064         * @return size
065         */
066        public static int calcSize(final int[] shape) {
067                long lsize = calcLongSize(shape);
068        
069                // check to see if the size is larger than an integer, i.e. we can't allocate it
070                if (lsize > Integer.MAX_VALUE) {
071                        throw new IllegalArgumentException("Size of the dataset is too large to allocate");
072                }
073                return (int) lsize;
074        }
075
076        /**
077         * Check if shapes are broadcast compatible
078         * 
079         * @param ashape first shape
080         * @param bshape second shape
081         * @return true if they are compatible
082         */
083        public static boolean areShapesBroadcastCompatible(final int[] ashape, final int[] bshape) {
084                if (ashape == null || bshape == null) {
085                        return ashape == bshape;
086                }
087
088                if (ashape.length < bshape.length) {
089                        return areShapesBroadcastCompatible(bshape, ashape);
090                }
091        
092                for (int a = ashape.length - bshape.length, b = 0; a < ashape.length && b < bshape.length; a++, b++) {
093                        if (ashape[a] != bshape[b] && ashape[a] != 1 && bshape[b] != 1) {
094                                return false;
095                        }
096                }
097        
098                return true;
099        }
100
101        /**
102         * Check if shapes are compatible, ignoring extra axes of length 1
103         * 
104         * @param ashape first shape
105         * @param bshape second shape
106         * @return true if they are compatible
107         */
108        public static boolean areShapesCompatible(final int[] ashape, final int[] bshape) {
109                if (ashape == null || bshape == null) {
110                        return ashape == bshape;
111                }
112
113                List<Integer> alist = new ArrayList<Integer>();
114        
115                for (int a : ashape) {
116                        if (a > 1) alist.add(a);
117                }
118        
119                final int imax = alist.size();
120                int i = 0;
121                for (int b : bshape) {
122                        if (b == 1)
123                                continue;
124                        if (i >= imax || b != alist.get(i++))
125                                return false;
126                }
127        
128                return i == imax;
129        }
130
131        /**
132         * Check if shapes are compatible but skip axis
133         * 
134         * @param ashape first shape
135         * @param bshape second shape
136         * @param axis to skip
137         * @return true if they are compatible
138         */
139        public static boolean areShapesCompatible(final int[] ashape, final int[] bshape, final int axis) {
140                if (ashape == null || bshape == null) {
141                        return ashape == bshape;
142                }
143
144                if (ashape.length != bshape.length) {
145                        return false;
146                }
147        
148                final int rank = ashape.length;
149                for (int i = 0; i < rank; i++) {
150                        if (i != axis && ashape[i] != bshape[i]) {
151                                return false;
152                        }
153                }
154                return true;
155        }
156
157        /**
158         * Remove dimensions of 1 in given shape - from both ends only, if true
159         * 
160         * @param shape dataset shape
161         * @param onlyFromEnds if true, trim ends
162         * @return newly squeezed shape (or original if unsqueezed)
163         */
164        public static int[] squeezeShape(final int[] shape, boolean onlyFromEnds) {
165                int unitDims = 0;
166                int rank = shape.length;
167                int start = 0;
168        
169                if (onlyFromEnds) {
170                        int i = rank - 1;
171                        for (; i >= 0; i--) {
172                                if (shape[i] == 1) {
173                                        unitDims++;
174                                } else {
175                                        break;
176                                }
177                        }
178                        for (int j = 0; j <= i; j++) {
179                                if (shape[j] == 1) {
180                                        unitDims++;
181                                } else {
182                                        start = j;
183                                        break;
184                                }
185                        }
186                } else {
187                        for (int i = 0; i < rank; i++) {
188                                if (shape[i] == 1) {
189                                        unitDims++;
190                                }
191                        }
192                }
193        
194                if (unitDims == 0) {
195                        return shape;
196                }
197        
198                int[] newDims = new int[rank - unitDims];
199                if (unitDims == rank)
200                        return newDims; // zero-rank dataset
201        
202                if (onlyFromEnds) {
203                        rank = newDims.length;
204                        for (int i = 0; i < rank; i++) {
205                                newDims[i] = shape[i+start];
206                        }
207                } else {
208                        int j = 0;
209                        for (int i = 0; i < rank; i++) {
210                                if (shape[i] > 1) {
211                                        newDims[j++] = shape[i];
212                                        if (j >= newDims.length)
213                                                break;
214                                }
215                        }
216                }
217        
218                return newDims;
219        }
220
221        /**
222         * Remove dimension of 1 in given shape
223         * 
224         * @param shape dataset shape
225         * @param axis to remove
226         * @return newly squeezed shape
227         */
228        public static int[] squeezeShape(final int[] shape, int axis) {
229                if (shape == null) {
230                        return null;
231                }
232
233                final int rank = shape.length;
234                if (rank == 0) {
235                        return new int[0];
236                }
237                if (axis < 0) {
238                        axis += rank;
239                }
240                if (axis < 0 || axis >= rank) {
241                        throw new IllegalArgumentException("Axis argument is outside allowed range");
242                }
243                int[] nshape = new int[rank-1];
244                for (int i = 0; i < axis; i++) {
245                        nshape[i] = shape[i];
246                }
247                for (int i = axis+1; i < rank; i++) {
248                        nshape[i-1] = shape[i];
249                }
250                return nshape;
251        }
252
253        /**
254         * Get shape from object (array or list supported)
255         * @param obj object
256         * @return shape can be null if obj is null
257         */
258        public static int[] getShapeFromObject(final Object obj) {
259                if (obj == null) {
260                        return null;
261                }
262
263                ArrayList<Integer> lshape = new ArrayList<Integer>();
264                getShapeFromObj(lshape, obj, 0);
265
266                final int rank = lshape.size();
267                final int[] shape = new int[rank];
268                for (int i = 0; i < rank; i++) {
269                        shape[i] = lshape.get(i);
270                }
271        
272                return shape;
273        }
274
275        /**
276         * Get shape from object
277         * @param ldims
278         * @param obj
279         * @param depth
280         * @return true if there is a possibility of differing lengths
281         */
282        private static boolean getShapeFromObj(final ArrayList<Integer> ldims, Object obj, int depth) {
283                if (obj == null)
284                        return true;
285        
286                if (obj instanceof List<?>) {
287                        List<?> jl = (List<?>) obj;
288                        int l = jl.size();
289                        updateShape(ldims, depth, l);
290                        for (int i = 0; i < l; i++) {
291                                Object lo = jl.get(i);
292                                if (!getShapeFromObj(ldims, lo, depth + 1)) {
293                                        break;
294                                }
295                        }
296                        return true;
297                }
298                Class<? extends Object> ca = obj.getClass().getComponentType();
299                if (ca != null) {
300                        final int l = Array.getLength(obj);
301                        updateShape(ldims, depth, l);
302                        if (InterfaceUtils.isElementSupported(ca)) {
303                                return true;
304                        }
305                        for (int i = 0; i < l; i++) {
306                                Object lo = Array.get(obj, i);
307                                if (!getShapeFromObj(ldims, lo, depth + 1)) {
308                                        break;
309                                }
310                        }
311                        return true;
312                } else if (obj instanceof IDataset) {
313                        int[] s = ((IDataset) obj).getShape();
314                        for (int i = 0; i < s.length; i++) {
315                                updateShape(ldims, depth++, s[i]);
316                        }
317                        return true;
318                } else {
319                        return false; // not an array of any type
320                }
321        }
322
323        private static void updateShape(final ArrayList<Integer> ldims, final int depth, final int l) {
324                if (depth >= ldims.size()) {
325                        ldims.add(l);
326                } else if (l > ldims.get(depth)) {
327                        ldims.set(depth, l);
328                }
329        }
330
331        /**
332         * Get n-D position from given index
333         * @param n absolute index
334         * @param shape dataset shape
335         * @return n-D position
336         */
337        public static int[] getNDPositionFromShape(int n, int[] shape) {
338                if (shape == null) {
339                        return null;
340                }
341
342                int rank = shape.length;
343                if (rank == 0) {
344                        return new int[0];
345                }
346
347                if (rank == 1) {
348                        return new int[] { n };
349                }
350
351                int[] output = new int[rank];
352                for (rank--; rank > 0; rank--) {
353                        output[rank] = n % shape[rank];
354                        n /= shape[rank];
355                }
356                output[0] = n;
357        
358                return output;
359        }
360
361        /**
362         * Get flattened view index of given position
363         * @param shape dataset shape
364         * @param pos
365         *            the integer array specifying the n-D position
366         * @return the index on the flattened dataset
367         */
368        public static int getFlat1DIndex(final int[] shape, final int[] pos) {
369                final int imax = pos.length;
370                if (imax == 0) {
371                        return 0;
372                }
373        
374                return AbstractDataset.get1DIndexFromShape(shape, pos);
375        }
376
377        /**
378         * This function takes a dataset and checks its shape against another dataset. If they are both of the same size,
379         * then this returns with no error, if there is a problem, then an error is thrown.
380         * 
381         * @param g
382         *            The first dataset to be compared
383         * @param h
384         *            The second dataset to be compared
385         * @throws IllegalArgumentException
386         *             This will be thrown if there is a problem with the compatibility
387         */
388        public static void checkCompatibility(final ILazyDataset g, final ILazyDataset h) throws IllegalArgumentException {
389                if (!areShapesCompatible(g.getShape(), h.getShape())) {
390                        throw new IllegalArgumentException("Shapes do not match");
391                }
392        }
393
394        /**
395         * Check that axis is in range [-rank,rank)
396         * 
397         * @param rank number of dimensions
398         * @param axis dimension to check
399         * @return sanitized axis in range [0, rank)
400         * @since 2.1
401         */
402        public static int checkAxis(int rank, int axis) {
403                if (axis < 0) {
404                        axis += rank;
405                }
406        
407                if (axis < 0 || axis >= rank) {
408                        throw new IllegalArgumentException("Axis " + axis + " given is out of range [0, " + rank + ")");
409                }
410                return axis;
411        }
412
413        private static int[] convert(Collection<Integer> list) {
414                int[] array = new int[list.size()];
415                int i = 0;
416                for (Integer l : list) {
417                        array[i++] = l;
418                }
419                return array;
420        }
421
422        /**
423         * Check that all axes are in range [-rank,rank)
424         * @param rank number of dimensions
425         * @param axes to skip
426         * @return sanitized axes in range [0, rank) and sorted in increasing order
427         * @since 2.2
428         */
429        public static int[] checkAxes(int rank, int... axes) {
430                return convert(sanitizeAxes(rank, axes));
431        }
432
433        /**
434         * Check that all axes are in range [-rank,rank)
435         * @param rank number of dimensions
436         * @param axes to skip
437         * @return sanitized axes in range [0, rank) and sorted in increasing order
438         * @since 2.2
439         */
440        private static SortedSet<Integer> sanitizeAxes(int rank, int... axes) {
441                SortedSet<Integer> nAxes = new TreeSet<>(); 
442                for (int i = 0; i < axes.length; i++) {
443                        nAxes.add(checkAxis(rank, axes[i]));
444                }
445
446                return nAxes;
447        }
448
449        /**
450         * @param rank number of dimensions
451         * @param axes to skip
452         * @return remaining axes not given by input
453         * @since 2.2
454         */
455        public static int[] getRemainingAxes(int rank, int... axes) {
456                SortedSet<Integer> nAxes = sanitizeAxes(rank, axes);
457
458                int[] remains = new int[rank - axes.length];
459                int j = 0;
460                for (int i = 0; i < rank; i++) {
461                        if (!nAxes.contains(i)) {
462                                remains[j++] = i;
463                        }
464                }
465                return remains;
466        }
467
468        /**
469         * Remove axes from shape
470         * @param shape to use
471         * @param axes to remove
472         * @return reduced shape
473         * @since 2.2
474         */
475        public static int[] reduceShape(int[] shape, int... axes) {
476                int[] remain = getRemainingAxes(shape.length, axes);
477                for (int i = 0; i < remain.length; i++) {
478                        int a = remain[i];
479                        remain[i] = shape[a];
480                }
481                return remain;
482        }
483
484        /**
485         * Set reduced axes to 1
486         * @param shape input
487         * @param axes to set to 1
488         * @return shape with same rank
489         * @since 2.2
490         */
491        public static int[] getReducedShapeKeepRank(int[] shape, int... axes) {
492                int[] keep = shape.clone();
493                axes = checkAxes(shape.length, axes);
494                for (int i : axes) {
495                        keep[i] = 1;
496                }
497                return keep;
498        }
499
500        /**
501         * @param a first shape
502         * @param b second shape
503         * @return true if arrays only differs by unit entries
504         * @since 2.2
505         */
506        public static boolean differsByOnes(int[] a, int[] b) {
507                int aRank = a.length;
508                int bRank = b.length;
509                int ai = 0;
510                int bi = 0;
511                int al = 1;
512                int bl = 1;
513                do {
514                        while (ai < aRank && (al = a[ai++]) == 1) { // next non-unit dimension
515                        }
516                        while (bi < bRank && (bl = b[bi++]) == 1) {
517                        }
518                        if (al != bl) {
519                                return false;
520                        }
521                } while (ai < aRank && bi < bRank);
522
523                if (ai == aRank) {
524                        while (bi < bRank) {
525                                if (b[bi++] != 1) {
526                                        return false;
527                                }
528                        }
529                }
530                if (bi == bRank) {
531                        while (ai < aRank) {
532                                if (a[ai++] != 1) {
533                                        return false;
534                                }
535                        }
536                }
537                return true;
538        }
539
540        /**
541         * Calculate the padding difference between two shapes. Padding can be positive (negative)
542         * for added (removed) dimensions. NB positive or negative padding is given after matched
543         * dimensions
544         * @param aShape first shape
545         * @param bShape second shape
546         * @return padding can be null if shapes are equal
547         * @throws IllegalArgumentException if one shape is null but not the other, or if shapes do
548         * not possess common non-unit lengths
549         * @since 2.2
550         */
551        public static int[] calcShapePadding(int[] aShape, int[] bShape) {
552                if (Arrays.equals(aShape, bShape)) {
553                        return null;
554                }
555
556                if (aShape == null || bShape == null) {
557                        throw new IllegalArgumentException("If one shape is null then the other must be null too");
558                }
559
560                if (!differsByOnes(aShape, bShape)) {
561                        throw new IllegalArgumentException("Non-unit lengths in shapes must be equal");
562                }
563                int aRank = aShape.length;
564                int bRank = bShape.length;
565
566                int[] padding;
567                if (aRank == 0 || bRank == 0) {
568                        padding = new int[1];
569                        padding[0] = aRank == 0 ? bRank : -aRank;
570                        return padding;
571                }
572
573                padding = new int[Math.max(aRank, bRank) + 2];
574                int ai = 0;
575                int bi = 0;
576                int al = 0;
577                int bl = 0;
578                int pi = 0;
579                int p;
580                boolean aLeft = ai < aRank;
581                boolean bLeft = bi < bRank;
582                while (aLeft && bLeft) {
583                        if (aLeft) {
584                                al = aShape[ai++];
585                                aLeft = ai < aRank;
586                        }
587                        if (bLeft) {
588                                bl = bShape[bi++];
589                                bLeft = bi < bRank;
590                        }
591                        if (al != bl) {
592                                p = 0;
593                                while (al == 1 && aLeft) {
594                                        al = aShape[ai++];
595                                        aLeft = ai < aRank;
596                                        p--;
597                                }
598                                while (bl == 1 && bLeft) {
599                                        bl = bShape[bi++];
600                                        bLeft = bi < bRank;
601                                        p++;
602                                }
603                                padding[pi++] = p;
604                        }
605                        if (al == bl) {
606                                pi++;
607                        }
608                }
609                if (aLeft || bLeft) {
610                        p = 0;
611                        while (ai < aRank && aShape[ai++] == 1) {
612                                p--;
613                        }
614                        while (bi < bRank && bShape[bi++] == 1) {
615                                p++;
616                        }
617                        padding[pi++] = p;
618                }
619
620                return Arrays.copyOf(padding, pi);
621        }
622
623        static int[] padShape(int[] padding, int nr, int[] oldShape) {
624                if (padding == null) {
625                        return oldShape.clone();
626                }
627                int or = oldShape.length;
628                int[] newShape = new int[nr];
629                int di = 0;
630                for (int i = 0, si = 0; i < padding.length && si <= or && di < nr; i++) {
631                        int c = padding[i];
632                        if (c == 0) {
633                                newShape[di++] = oldShape[si++];
634                        } else if (c > 0) {
635                                int dim = di + c;
636                                while (di < dim) {
637                                        newShape[di++] = 1;
638                                }
639                        } else if (c < 0) {
640                                si -= c; // remove dimensions by skipping forward in source array (should check that they are unit entries)
641                        }
642                }
643                while (di < nr) {
644                        newShape[di++] = 1;
645                }
646                return newShape;
647        }
648}