import javax.swing.*;
import java.awt.*;
import java.awt.event.*;
import java.util.*;

public class GradDescent extends JFrame implements ActionListener
{
  private JButton but1;           // button for generating the point objects
  private JTextField tf1;         // input text field for the # of objects
  private GradDescentAlgo drawPanel;     // the main panel for drawing
  public JLabel dropped;          // status label; the drawing panel writes the hidden-label count here

  /**
   * Entry point: creates and shows the application window.
   *
   * <p>Swing components must be created and used on the event dispatch thread,
   * so the whole construction is scheduled there. Because window events are
   * also delivered on that thread, the close operation is in place before the
   * user can interact with the frame.
   */
  public static void main(String[] args)
  {
    SwingUtilities.invokeLater(new Runnable()
    {
      @Override
      public void run()
      {
        GradDescent application = new GradDescent();
        application.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
      }
    });
  }

  /**
   * Builds the GUI: a drawing panel in the centre and a control column on the
   * right (count field, "Generate" button, status label). Then shows the frame.
   * The frame is 3/4 of the screen size and centred on the screen.
   */
  public GradDescent()
  {
    Toolkit toolkit = Toolkit.getDefaultToolkit();  // size and position the
    Dimension screenSize = toolkit.getScreenSize(); // window relative to the
    int screenWidth = screenSize.width;             // screen: 3/4 of its size,
    int screenHeight = screenSize.height;           // offset by 1/8 so it is
    setSize(3*screenWidth/4, 3*screenHeight/4);     // centred
    setLocation(screenWidth/8, screenHeight/8);
    setTitle("Gradient Descent method for label placement");

    Container pane = getContentPane();              // building GUI
    pane.setLayout(new BorderLayout());
    dropped = new JLabel("", SwingConstants.LEFT);
    drawPanel = new GradDescentAlgo(dropped);       // main panel for drawing
    drawPanel.setBackground(new Color(0xffffe0));   // points and labels
    pane.add(drawPanel, BorderLayout.CENTER);

    JPanel controlPanel = new JPanel();             // the interface panel
    controlPanel.setBackground(new Color(0xcccc99));
    controlPanel.setLayout(new GridLayout(5,1));
    tf1 = new JTextField("100");
    but1 = new JButton("Generate");
    but1.addActionListener(this);
    controlPanel.add(new JLabel("# of rects", SwingConstants.CENTER));
    controlPanel.add(tf1);
    controlPanel.add(but1);
    controlPanel.add(new JLabel("---------", SwingConstants.CENTER));
    controlPanel.add(dropped);
    // Wrapping in a FlowLayout panel keeps the grid at its preferred size
    // instead of stretching it over the whole EAST region.
    JPanel controlPanel2 = new JPanel();
    controlPanel2.setBackground(new Color(0xcccc99));
    controlPanel2.add(controlPanel);
    pane.add(controlPanel2, BorderLayout.EAST);

    setVisible(true);
    drawPanel.repaint();
  }

  /**
   * Handles the "Generate" button. It reads the requested number of points,
   * runs the placement algorithm, and repaints the result. Invalid input is
   * reported by {@link #getNumber(JTextField)} and ignored here.
   */
  @Override
  public void actionPerformed(ActionEvent e)
  {
    if (e.getSource() == but1)
    {
      int n = getNumber(tf1);                  // # of rectangles to generate
      if (n < 0) return;                       // error already shown
      drawPanel.algo(n);                       // run the placement algo
      drawPanel.repaint();                     // update the drawing
    }
  }

  /**
   * Reads a positive integer from a text field.
   *
   * @param tf the field to read; leading and trailing whitespace is ignored
   * @return the parsed value. If the text is not a positive integer, an error
   *         dialog is shown and -1 is returned.
   */
  public int getNumber(JTextField tf)
  {
    int n;
    try
    {
      n = Integer.parseInt(tf.getText().trim());
    }
    catch (NumberFormatException e)
    {
      n = -1;                                  // treated as invalid below
    }

    if (n <= 0)
    {
      JOptionPane.showMessageDialog(this, "Enter a positive integer",
      "Error message", JOptionPane.ERROR_MESSAGE);
      return(-1);
    }
    return(n);
  }
}

/**
 * Drawing panel that scatters labelled points at random and places their
 * labels with a greedy steepest-descent search.
 *
 * <p>Each point i has four candidate label rectangles, stored at indices
 * {@code 4*i + p} of {@code rects}:
 * <ul>
 *   <li>p = 0: top-right</li>
 *   <li>p = 1: top-left</li>
 *   <li>p = 2: bottom-left</li>
 *   <li>p = 3: bottom-right</li>
 * </ul>
 * A label may also be hidden (pos 4).
 *
 * <p>The layout cost is the sum of every label's pos value, so a hidden label
 * costs 4. On top of that, {@code OVERLAP_COST} is added for every pair of
 * overlapping shown labels.
 *
 * <p>{@link #algo(int)} repeatedly applies the single move that lowers the
 * cost the most, where a move takes one label to one of its four shown
 * positions. It stops when no move improves the cost.
 */
class GradDescentAlgo extends JPanel
{
  private static final int OVERLAP_COST = 10;  // cost of one overlapping pair of shown labels
  private static final int POSITIONS = 4;      // candidate shown positions per point
  private static final int DROPPED = 4;        // pos value of a hidden label (also its cost)
  private static final int GAP = 5;            // pixels between a point and its label box

  /** A point with its label text, measured label size and chosen position. */
  private static class Obj
  {
    int x, y, delta, w, h, pos;
    String label;

    Obj(int x, int y, int pos, String label, int w, int h)
    {
      this.x = x;
      this.y = y;
      this.pos = pos;
      this.label = label;
      this.w = w;
      this.h = h;
      delta = GAP;
    }
  }

  private int n;                  // number of points
  private Obj[] objects;          // the points, indexed 0..n-1
  private JLabel dropped;         // status label receiving the hidden-label count
  private int totalCost = 0;      // cost of the current layout
  private int[][] adjList;        // adjList[r] = candidate rects that overlap rect r
  private boolean[] v;            // v[r] == true when candidate rect r is shown
  private int[][] rects;          // rects[r] = {left, top, right, bottom}

  /**
   * @param dropped label that receives "dropped: k" after each run of
   *                {@link #algo(int)}
   */
  public GradDescentAlgo(JLabel dropped)
  {
    this.dropped = dropped;
  }

  /**
   * Creates n random points inside the panel. Each point gets a random
   * initial position (0..3 shown, 4 hidden).
   *
   * <p>This method also builds the four candidate rectangles per point and
   * the overlap graph between all candidate rectangles. It resets
   * {@code totalCost} to the cost of this initial layout.
   *
   * @param n number of points; 0 yields an empty layout
   * @throws IllegalArgumentException if n is negative
   */
  public void generateObjects(int n)
  {
    if (n < 0)
      throw new IllegalArgumentException("n must be non-negative: " + n);

    this.n = n;
    int n4 = POSITIONS * n;
    objects = new Obj[n];
    v = new boolean[n4];
    rects = new int[n4][4];
    totalCost = 0;                              // each run starts from scratch

    int width = getWidth();
    int height = getHeight();
    // Measure with the panel's own font, which is the font paintComponent draws
    // with. Unlike getGraphics(), this needs no displayable peer and leaves no
    // Graphics object to dispose.
    Font font = getFont();
    if (font == null)
      font = new Font("Dialog", Font.PLAIN, 12);
    FontMetrics fm = getFontMetrics(font);

    for (int i = 0; i < n; i++)
    {
      int y = 20 + (int)Math.round(Math.random()*(height - 40));
      int x = 40 + (int)Math.round(Math.random()*(width - 80));
      int pos = (int)(Math.random()*(POSITIONS + 1));   // 0..3 shown, 4 hidden
      String s = "label" + i;
      Obj o = new Obj(x, y, pos, s, fm.stringWidth(s), fm.getAscent());
      objects[i] = o;
      if (pos != DROPPED)
        v[POSITIONS*i + pos] = true;
      totalCost += pos;

      int d = o.delta;
      int i4 = POSITIONS * i;
      setRect(i4,     x + d,       y - d - o.h, x + d + o.w, y - d);        // top-right
      setRect(i4 + 1, x - d - o.w, y - d - o.h, x - d,       y - d);        // top-left
      setRect(i4 + 2, x - d - o.w, y + d,       x - d,       y + d + o.h);  // bottom-left
      setRect(i4 + 3, x + d,       y + d,       x + d + o.w, y + d + o.h);  // bottom-right
    }

    // Build the intersection graph, counting each overlapping shown pair once.
    ArrayList<ArrayList<Integer>> neighbours = new ArrayList<ArrayList<Integer>>(n4);
    for (int i = 0; i < n4; i++)
    {
      neighbours.add(new ArrayList<Integer>());
      for (int j = 0; j < i; j++)
        if (rectIntersect(i, j))
        {
          neighbours.get(i).add(j);
          neighbours.get(j).add(i);
          if (v[i] && v[j])
            totalCost += OVERLAP_COST;          // same per-pair cost algo() uses
        }
    }

    adjList = new int[n4][];
    for (int i = 0; i < n4; i++)                // flatten to arrays for fast scans
    {
      ArrayList<Integer> list = neighbours.get(i);
      adjList[i] = new int[list.size()];
      for (int j = 0; j < list.size(); j++)
        adjList[i][j] = list.get(j);
    }
  }

  /**
   * Generates n points and greedily improves their label placement until no
   * single move lowers the cost. Afterwards the status label shows how many
   * labels remain hidden.
   *
   * @param n number of points
   * @throws IllegalArgumentException if n is negative
   */
  public void algo(int n)
  {
    generateObjects(n);
    // Each accepted move strictly lowers the integer totalCost. The cost is
    // bounded below, so this loop terminates.
    while (applyBestMove())
    {
      // keep descending
    }
    dropped.setText("dropped: " + countDropped());
  }

  /**
   * Finds the move with the largest cost decrease and applies it. A move takes
   * one label to its cheapest shown position.
   *
   * @return true if an improving move was applied, false at a local minimum
   */
  private boolean applyBestMove()
  {
    int bestDiff = 0;           // only strictly improving moves are accepted
    int bestLabel = -1;
    int bestPos = 0;

    for (int i = 0; i < n; i++)
    {
      int cur = objects[i].pos;
      int diff = -cur;                                   // cost of taking the label off
      if (cur != DROPPED)
        diff -= OVERLAP_COST * activeNeighbours(POSITIONS*i + cur);

      // A point's own candidate rects are separated by the gap and never
      // overlap. So the label's current rect is never counted here.
      int min = Integer.MAX_VALUE;
      int minPos = 0;
      for (int k = 0; k < POSITIONS; k++)
      {
        int d = k + OVERLAP_COST * activeNeighbours(POSITIONS*i + k);
        if (d < min)
        {
          min = d;
          minPos = k;
        }
      }
      diff += min;

      if (diff < bestDiff)
      {
        bestDiff = diff;
        bestLabel = i;
        bestPos = minPos;
      }
    }

    if (bestLabel < 0)
      return false;
    moveLabel(bestLabel, bestPos);
    return true;
  }

  /** Moves label i to shown position newPos, keeping v and totalCost in sync. */
  private void moveLabel(int i, int newPos)
  {
    Obj o = objects[i];
    if (o.pos != DROPPED)
    {
      int oldIndex = POSITIONS*i + o.pos;
      totalCost -= OVERLAP_COST * activeNeighbours(oldIndex);
      v[oldIndex] = false;
    }

    int newIndex = POSITIONS*i + newPos;
    v[newIndex] = true;
    totalCost += newPos - o.pos;
    o.pos = newPos;
    totalCost += OVERLAP_COST * activeNeighbours(newIndex);
  }

  /** Returns how many currently shown rects overlap candidate rect r. */
  private int activeNeighbours(int r)
  {
    int count = 0;
    for (int other : adjList[r])
      if (v[other])
        count++;
    return count;
  }

  /** Returns the number of labels that are currently hidden. */
  private int countDropped()
  {
    int count = 0;
    for (Obj o : objects)
      if (o.pos == DROPPED)
        count++;
    return count;
  }

  /** Stores the bounding box {left, top, right, bottom} of candidate rect r. */
  private void setRect(int r, int left, int top, int right, int bottom)
  {
    rects[r][0] = left;
    rects[r][1] = top;
    rects[r][2] = right;
    rects[r][3] = bottom;
  }

  /**
   * Checks whether the bounding boxes of candidate rects i and j overlap.
   * Touching edges count as overlapping.
   */
  private boolean rectIntersect(int i, int j)
  {
    int[] a = rects[i];
    int[] b = rects[j];
    return b[0] <= a[2] && a[0] <= b[2] && a[1] <= b[3] && b[1] <= a[3];
  }

  /**
   * Draws every point, coloured red if its label is hidden and black otherwise.
   * Shown labels are drawn as a blue text box with a short leader line to
   * their point.
   */
  @Override
  public void paintComponent(Graphics g)
  {
    super.paintComponent(g);
    if (objects == null)
      return;

    for (Obj o : objects)
    {
      g.setColor(o.pos == DROPPED ? Color.red : Color.black);
      g.fillOval(o.x-5, o.y-5, 10, 10);          // the point itself
      if (o.pos == DROPPED)
        continue;

      int x = o.x;                               // top-left corner of the label box
      int y = o.y;
      int dx = 1;                                // leader-line direction
      int dy = 1;
      switch (o.pos)
      {
        case 0: x += o.delta; y -= o.h+o.delta; dx=1; dy=-1; break;
        case 1: x -= o.w+o.delta; y -= o.h+o.delta; dx=-1; dy=-1; break;
        case 2: x -= o.w+o.delta; y += o.delta; dx=-1; dy=1; break;
        case 3: x += o.delta; y += o.delta; dx=1; dy=1; break;
        default: break;
      }

      g.setColor(Color.blue);
      g.drawString(o.label, x, y+o.h);
      g.drawRect(x, y, o.w, o.h);
      g.drawLine(o.x, o.y, o.x+dx*o.delta, o.y+dy*o.delta);
    }
  }
}

