001// License: GPL. For details, see LICENSE file.
002package org.openstreetmap.josm.tools;
003
004import java.awt.Dimension;
005import java.awt.geom.Point2D;
006import java.awt.geom.Rectangle2D;
007import java.awt.image.BufferedImage;
008import java.util.HashMap;
009import java.util.HashSet;
010import java.util.Map;
011import java.util.Objects;
012import java.util.Set;
013
014/**
015 * Image warping algorithm.
016 *
017 * Deforms an image geometrically according to a given transformation formula.
018 * @since 11858
019 */
020public final class ImageWarp {
021
022    private ImageWarp() {
023        // Hide default constructor
024    }
025
026    /**
027     * Transformation that translates the pixel coordinates.
028     */
029    public interface PointTransform {
030        /**
031         * Translates pixel coordinates.
032         * @param pt pixel coordinates
033         * @return transformed pixel coordinates
034         */
035        Point2D transform(Point2D pt);
036    }
037
038    /**
039     * Wrapper that optimizes a given {@link ImageWarp.PointTransform}.
040     *
041     * It does so by spanning a grid with certain step size. It will invoke the
042     * potentially expensive master transform only at those grid points and use
043     * bilinear interpolation to approximate transformed values in between.
044     * <p>
045     * For memory optimization, this class assumes that rows are more or less scanned
046     * one-by-one as is done in {@link ImageWarp#warp}. I.e. this transform is <em>not</em>
047     * random access in the y coordinate.
048     */
049    public static class GridTransform implements ImageWarp.PointTransform {
050
051        private final double stride;
052        private final ImageWarp.PointTransform trfm;
053
054        private final Map<Integer, Map<Integer, Point2D>> cache;
055
056        private final boolean consistencyTest;
057        private final Set<Integer> deletedRows;
058
059        /**
060         * Create a new GridTransform.
061         * @param trfm the master transform, that needs to be optimized
062         * @param stride step size
063         */
064        public GridTransform(ImageWarp.PointTransform trfm, double stride) {
065            this.trfm = trfm;
066            this.stride = stride;
067            this.cache = new HashMap<>();
068            this.consistencyTest = Logging.isDebugEnabled();
069            if (consistencyTest) {
070                deletedRows = new HashSet<>();
071            } else {
072                deletedRows = null;
073            }
074        }
075
076        @Override
077        public Point2D transform(Point2D pt) {
078            int xIdx = (int) Math.floor(pt.getX() / stride);
079            int yIdx = (int) Math.floor(pt.getY() / stride);
080            double dx = pt.getX() / stride - xIdx;
081            double dy = pt.getY() / stride - yIdx;
082            Point2D value00 = getValue(xIdx, yIdx);
083            Point2D value01 = getValue(xIdx, yIdx + 1);
084            Point2D value10 = getValue(xIdx + 1, yIdx);
085            Point2D value11 = getValue(xIdx + 1, yIdx + 1);
086            double valueX = (value00.getX() * (1-dx) + value10.getX() * dx) * (1-dy) +
087                    (value01.getX() * (1-dx) + value11.getX() * dx) * dy;
088            double valueY = (value00.getY() * (1-dx) + value10.getY() * dx) * (1-dy) +
089                    (value01.getY() * (1-dx) + value11.getY() * dx) * dy;
090            return new Point2D.Double(valueX, valueY);
091        }
092
093        private Point2D getValue(int xIdx, int yIdx) {
094            return getRow(yIdx).computeIfAbsent(xIdx, k -> trfm.transform(new Point2D.Double(xIdx * stride, yIdx * stride)));
095        }
096
097        private Map<Integer, Point2D> getRow(int yIdx) {
098            cleanUp(yIdx - 3);
099            Map<Integer, Point2D> row = cache.get(yIdx);
100            if (row == null) {
101                row = new HashMap<>();
102                cache.put(yIdx, row);
103                if (consistencyTest) {
104                    // should not create a row that has been deleted before
105                    if (deletedRows.contains(yIdx)) throw new AssertionError();
106                    // only ever cache 3 rows at once
107                    if (cache.size() > 3) throw new AssertionError();
108                }
109            }
110            return row;
111        }
112
113        // remove rows from cache that will no longer be used
114        private void cleanUp(int yIdx) {
115            Map<Integer, Point2D> del = cache.remove(yIdx);
116            if (consistencyTest && del != null) {
117                // should delete each row only once
118                if (deletedRows.contains(yIdx)) throw new AssertionError();
119                deletedRows.add(yIdx);
120            }
121        }
122    }
123
124    /**
125     * Interpolation method.
126     */
127    public enum Interpolation {
128        /**
129         * Nearest neighbor.
130         *
131         * Simplest possible method. Faster, but not very good quality.
132         */
133        NEAREST_NEIGHBOR,
134
135        /**
136         * Bilinear.
137         *
138         * Decent quality.
139         */
140        BILINEAR;
141    }
142
143    /**
144     * Warp an image.
145     * @param srcImg the original image
146     * @param targetDim dimension of the target image
147     * @param invTransform inverse transformation (translates pixel coordinates
148     * of the target image to pixel coordinates of the original image)
149     * @param interpolation the interpolation method
150     * @return the warped image
151     */
152    public static BufferedImage warp(BufferedImage srcImg, Dimension targetDim, PointTransform invTransform, Interpolation interpolation) {
153        BufferedImage imgTarget = new BufferedImage(targetDim.width, targetDim.height, BufferedImage.TYPE_INT_ARGB);
154        Rectangle2D srcRect = new Rectangle2D.Double(0, 0, srcImg.getWidth(), srcImg.getHeight());
155        for (int j = 0; j < imgTarget.getHeight(); j++) {
156            for (int i = 0; i < imgTarget.getWidth(); i++) {
157                Point2D srcCoord = invTransform.transform(new Point2D.Double(i, j));
158                if (srcRect.contains(srcCoord)) {
159                    int rgba;
160                    switch (interpolation) {
161                        case NEAREST_NEIGHBOR:
162                            rgba = getColor((int) Math.round(srcCoord.getX()), (int) Math.round(srcCoord.getY()), srcImg);
163                            break;
164                        case BILINEAR:
165                            int x0 = (int) Math.floor(srcCoord.getX());
166                            double dx = srcCoord.getX() - x0;
167                            int y0 = (int) Math.floor(srcCoord.getY());
168                            double dy = srcCoord.getY() - y0;
169                            int c00 = getColor(x0, y0, srcImg);
170                            int c01 = getColor(x0, y0 + 1, srcImg);
171                            int c10 = getColor(x0 + 1, y0, srcImg);
172                            int c11 = getColor(x0 + 1, y0 + 1, srcImg);
173                            rgba = 0;
174                            // loop over color components: blue, green, red, alpha
175                            for (int ch = 0; ch <= 3; ch++) {
176                                int shift = 8 * ch;
177                                int chVal = (int) Math.round(
178                                    (((c00 >> shift) & 0xff) * (1-dx) + ((c10 >> shift) & 0xff) * dx) * (1-dy) +
179                                    (((c01 >> shift) & 0xff) * (1-dx) + ((c11 >> shift) & 0xff) * dx) * dy);
180                                rgba |= chVal << shift;
181                            }
182                            break;
183                        default:
184                            throw new AssertionError(Objects.toString(interpolation));
185                    }
186                    imgTarget.setRGB(i, j, rgba);
187                }
188            }
189        }
190        return imgTarget;
191    }
192
193    private static int getColor(int x, int y, BufferedImage img) {
194        // border strategy: continue with the color of the outermost pixel,
195        return img.getRGB(
196                Utils.clamp(x, 0, img.getWidth() - 1),
197                Utils.clamp(y, 0, img.getHeight() - 1));
198    }
199}