CPSC 240

Java Threads

 

Overview

The following graph shows the clock speed of Intel processors from the 1970s until 2015.

Note that the Y-axis on this graph uses a logarithmic scale. After growing rapidly for decades, clock speeds have barely increased in recent years.

Recent processors have gained little from higher clock speeds; instead, they come with multiple cores. To take advantage of these cores, we need to write programs that use more than one thread.

A thread is an independent sequence of execution within a program. A multi-threaded program has more than one thread running at the same time; when the machine has multiple cores, these threads can execute in parallel.
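How many threads can actually run at the same time depends on the machine. As a minimal sketch (the class name Cores is made up for illustration), a Java program can ask the runtime how many processors are available to it:

```java
// Cores.java (hypothetical example)
class Cores {
    // Number of processors available to the JVM: roughly the maximum
    // number of threads that can be executing at the same instant.
    static int coreCount() {
        return Runtime.getRuntime().availableProcessors();
    }

    public static void main(String[] args) {
        System.out.printf("This JVM can use %d processor(s).%n", coreCount());
    }
}
```

The value can differ from machine to machine (and the JVM may report fewer processors than the hardware has if it is restricted), so it is a hint for choosing a thread count rather than a fixed rule.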


 

Threads in Java

To create a thread in Java, first create a class that extends the Thread class and overrides its "run" method.

Then, create an object of that class. Lastly, call the object's "start" method. This creates a new thread of execution, which begins running in the "run" method.

This example shows a very simple multi-threaded program:


// Simple.java

class MyThread extends Thread {
    @Override
    public void run() {
        for (int i = 0; i < 10; i++) {
            try {
                System.out.printf("In thread, i = %d.\n", i);
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                System.out.printf("Interrupted!\n");
            }
        }
    }
}

public class Simple {
    public static void main(String args[]) {
        MyThread t = new MyThread();
        t.start();

        for (int i = 0; i < 10; i++) {
            try {
                System.out.printf("In main, i = %d.\n", i);
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                System.out.printf("Interrupted!\n");
            }
        }
    }
}
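Note that it is start, not run, that creates the new thread: calling run directly is just an ordinary method call on the current thread. The following sketch (the class names RecordingThread and StartVsRun are hypothetical) records which thread actually executes run in each case:

```java
// StartVsRun.java (hypothetical example)
class RecordingThread extends Thread {
    // the thread that executed run(); volatile so other threads see the write
    private volatile Thread ranOn;

    @Override
    public void run() {
        ranOn = Thread.currentThread();
    }

    Thread getRanOn() {
        return ranOn;
    }
}

class StartVsRun {
    // Calling run() directly: no new thread is created, the caller runs it.
    static boolean runUsesCaller() {
        RecordingThread t = new RecordingThread();
        t.run();
        return t.getRanOn() == Thread.currentThread();
    }

    // Calling start(): the JVM creates a new thread, which then calls run().
    static boolean startUsesNewThread() {
        RecordingThread t = new RecordingThread();
        t.start();
        try {
            t.join(); // wait until run() has finished
        } catch (InterruptedException e) {
            return false;
        }
        // inside the new thread, Thread.currentThread() is t itself
        return t.getRanOn() == t;
    }

    public static void main(String[] args) {
        System.out.println("run() executes on the caller: " + runUsesCaller());
        System.out.println("start() executes on a new thread: " + startUsesNewThread());
    }
}
```

Both lines print true. If Simple.java called t.run() instead of t.start(), all of the "In thread" lines would print before any of the "In main" lines.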

 

Sum Example

Multi-threading is useful when a computation can be split into parts that run in parallel. For example, we could sum a range of numbers by breaking the range into sub-ranges, summing each sub-range in its own thread, and adding the partial sums.

The following example attempts to implement this:


// Sum1.java

class SumThread extends Thread {
    private int from, to, sum;

    public SumThread(int from, int to) {
        this.from = from;
        this.to = to;
        sum = 0;
    }

    @Override
    public void run() {
        for (int i = from; i <= to; i++) {
            sum += i;
        } 
    }

    public int getSum() {
        return sum;
    }
}


public class Sum1 {
    public static void main(String args[]) {
        SumThread t1 = new SumThread(1, 500);
        SumThread t2 = new SumThread(501, 1000);
        t1.start();
        t2.start();

        // problem: main may read the sums before the threads finish computing them!
        System.out.printf("The sum of 1-1000 is %d.\n", t1.getSum() + t2.getSum());
    }
}

 

Joining Threads

Sum1 does not work reliably, because main keeps running while the two threads are still computing their sums. Main may call getSum before the sums are fully computed, so it can print a wrong total.

To fix this, we call the "join" method on each thread. join waits for that thread to finish, and it also guarantees that main sees the values the thread wrote. This is done in the following program:


// Sum2.java

class SumThread extends Thread {
    private int from, to, sum;

    public SumThread(int from, int to) {
        this.from = from;
        this.to = to;
        sum = 0;
    }

    @Override
    public void run() {
        for (int i = from; i <= to; i++) {
            sum += i;
        } 
    }

    public int getSum() {
        return sum;
    }
}


public class Sum2 {
    public static void main(String args[]) {
        SumThread t1 = new SumThread(1, 500);
        SumThread t2 = new SumThread(501, 1000);
        t1.start();
        t2.start();

        // wait for the threads to finish!
        try {
            t1.join();
            t2.join();
        } catch (InterruptedException e) {
            System.out.println("Interrupted");
        }

        System.out.printf("The sum of 1-1000 is %d.\n", t1.getSum() + t2.getSum());
    }
}
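A quick way to check the output is the closed form 1 + 2 + ... + n = n(n + 1)/2, so the sum of 1-1000 should be 1000 × 1001 / 2 = 500500. The sketch below (hypothetical class names; it uses long rather than int) repeats the two-thread sum and compares it against the formula:

```java
// CheckSum.java (hypothetical example)
class RangeSum extends Thread {
    private final long from, to;
    private long sum = 0;

    RangeSum(long from, long to) {
        this.from = from;
        this.to = to;
    }

    @Override
    public void run() {
        for (long i = from; i <= to; i++) {
            sum += i;
        }
    }

    long getSum() {
        return sum;
    }
}

class CheckSum {
    // closed form for 1 + 2 + ... + n
    static long formula(long n) {
        return n * (n + 1) / 2;
    }

    // sums 1..n with two threads, joining both before reading their results
    static long threaded(long n) {
        RangeSum first = new RangeSum(1, n / 2);
        RangeSum second = new RangeSum(n / 2 + 1, n);
        first.start();
        second.start();
        try {
            // join waits for each thread and makes its writes to sum visible here
            first.join();
            second.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting", e);
        }
        return first.getSum() + second.getSum();
    }

    public static void main(String[] args) {
        System.out.println("threads: " + threaded(1000)); // 500500
        System.out.println("formula: " + formula(1000));  // 500500
    }
}
```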

 

More Threads

Rather than hard-coding two threads, we may want to choose the number of threads when the program runs. To do this, we can create an array of threads and divide the range among them, as in the following example:


// Sum3.java

class SumThread extends Thread {
    private int from, to, sum;

    public SumThread(int from, int to) {
        this.from = from;
        this.to = to;
        sum = 0;
    }

    @Override
    public void run() {
        for (int i = from; i <= to; i++) {
            sum += i;
        } 
    }

    public int getSum() {
        return sum;
    }
}


public class Sum3 {
    public static void main(String args[]) {
        if (args.length != 3) {
            System.out.println("Pass threads, start and end.");
            return;
        }

        // get arguments
        int num_threads = Integer.parseInt(args[0]);
        int from = Integer.parseInt(args[1]);
        int to = Integer.parseInt(args[2]);

        // an array of threads
        SumThread [] threads = new SumThread[num_threads];

        // fill in the start/end ranges for each
        int step = (to - from) / num_threads;
        for (int i = 0; i < num_threads; i++) {
            int start = from + step * i;
            int stop = (start + step) - 1;

            // make sure we go all the way to the end on last thread
            if (i == (num_threads - 1)) {
                stop = to;         
            }

            System.out.printf("Thread %d sums from %d to %d.\n", i, start, stop);
            threads[i] = new SumThread(start, stop);
        }

        // start them all
        for (int i = 0; i < num_threads; i++) {
            threads[i].start();
        }

        // wait for all the threads to finish!
        try {
            for (int i = 0; i < num_threads; i++) {
                threads[i].join();
            }
        } catch (InterruptedException e) {
            System.out.println("Interrupted");
        }

        // calculate total sum
        int total_sum = 0;
        for (int i = 0; i < num_threads; i++) {
            total_sum += threads[i].getSum();
        }

        System.out.printf("The sum of %d-%d is %d.\n", from, to, total_sum);
    }
}
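The splitting loop is easy to get wrong, so it can help to pull it out and check it on its own. This sketch (the Ranges helper is hypothetical) uses the same arithmetic as Sum3 and verifies that the sub-ranges are back-to-back. For instance, 1-10 with 3 threads gives 1-3, 4-6 and 7-10; the last thread picks up the leftover numbers:

```java
// Ranges.java (hypothetical helper mirroring the splitting loop in Sum3)
import java.util.Arrays;

class Ranges {
    // Returns {start, stop} pairs using the same arithmetic as Sum3.
    static int[][] split(int from, int to, int numThreads) {
        int[][] ranges = new int[numThreads][2];
        int step = (to - from) / numThreads;
        for (int i = 0; i < numThreads; i++) {
            int start = from + step * i;
            int stop = (start + step) - 1;
            if (i == numThreads - 1) {
                stop = to; // last thread takes whatever is left over
            }
            ranges[i][0] = start;
            ranges[i][1] = stop;
        }
        return ranges;
    }

    // True if each range starts right after the previous one ends and the
    // ranges together span exactly from..to (no gaps, no overlaps).
    static boolean coversExactly(int[][] ranges, int from, int to) {
        int expected = from;
        for (int[] r : ranges) {
            if (r[0] != expected) {
                return false;
            }
            expected = r[1] + 1;
        }
        return expected == to + 1;
    }

    public static void main(String[] args) {
        int[][] r = split(1, 10, 3);
        System.out.println(Arrays.deepToString(r)); // [[1, 3], [4, 6], [7, 10]]
        System.out.println(coversExactly(r, 1, 10)); // true
    }
}
```

Because step rounds down, the last thread can get noticeably more work than the others when the range does not divide evenly.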

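Sum3 takes the number of threads and the range as command line parameters. Unfortunately, because the sums are stored in an int, they overflow once a total exceeds 2,147,483,647 (Integer.MAX_VALUE). Summing 1 to 100,000, for example, should give 5,000,050,000, which is already too large.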
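The overflow is silent: Java int arithmetic wraps around rather than reporting an error. The sketch below (hypothetical class name) sums 1 to n in an int and in a long, and also uses Math.addExact, which throws an ArithmeticException instead of wrapping:

```java
// Overflow.java (hypothetical example)
class Overflow {
    // Sums 1..n in an int; silently wraps once the total passes Integer.MAX_VALUE.
    static int intSum(int n) {
        int sum = 0;
        for (int i = 1; i <= n; i++) {
            sum += i;
        }
        return sum;
    }

    // Same loop in a long, which is wide enough for these values of n.
    static long longSum(int n) {
        long sum = 0;
        for (int i = 1; i <= n; i++) {
            sum += i;
        }
        return sum;
    }

    // Math.addExact throws ArithmeticException instead of wrapping around.
    static boolean overflowDetected(int n) {
        int sum = 0;
        try {
            for (int i = 1; i <= n; i++) {
                sum = Math.addExact(sum, i);
            }
        } catch (ArithmeticException e) {
            return true;
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(intSum(100000));           // 705082704 (wrong)
        System.out.println(longSum(100000));          // 5000050000 (correct)
        System.out.println(overflowDetected(100000)); // true
    }
}
```

The largest n whose sum still fits in an int is 65,535 (giving 2,147,450,880). A long pushes the limit much further; BigInteger removes it.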

To solve that, we can use Java's BigInteger class (in java.math), which can represent integers of any size. The following program uses BigIntegers for the sums:


// Sum4.java

import java.math.BigInteger;

class SumThread extends Thread {
    private BigInteger from, to, sum;
    
    public SumThread(BigInteger from, BigInteger to) {
        this.from = from;
        this.to = to;
        sum = BigInteger.ZERO;
    }

    @Override
    public void run() {
        for (BigInteger i = from; i.compareTo(to) <= 0; i = i.add(BigInteger.ONE)) {
            sum = sum.add(i);
        } 
    }

    public BigInteger getSum() {
        return sum;
    }
}


public class Sum4 {
    public static void main(String args[]) {
        if (args.length != 3) {
            System.out.println("Pass threads, start and end.");
            return;
        }

        // get arguments
        int num_threads = Integer.parseInt(args[0]);
        BigInteger from = new BigInteger(args[1]);
        BigInteger to = new BigInteger(args[2]);

        // an array of threads
        SumThread [] threads = new SumThread[num_threads];

        //  fill in the start/end ranges for each
        BigInteger step = to.subtract(from).divide(BigInteger.valueOf(num_threads));
        for (int i = 0; i < num_threads; i++) {
            BigInteger start = from.add(step.multiply(BigInteger.valueOf(i)));
            BigInteger stop = start.add(step).subtract(BigInteger.ONE);

            // make sure we go all the way to the end on last thread
            if (i == (num_threads - 1)) {
                stop = to;         
            }

            System.out.printf("Thread %d sums from %s to %s.\n", i, start, stop);
            threads[i] = new SumThread(start, stop);
        }

        // start them all
        for (int i = 0; i < num_threads; i++) {
            threads[i].start();
        }

        // wait for all the threads to finish!
        try {
            for (int i = 0; i < num_threads; i++) {
                threads[i].join();
            }
        } catch (InterruptedException e) {
            System.out.println("Interrupted");
        }

        // calculate total sum
        BigInteger total_sum = BigInteger.ZERO;
        for (int i = 0; i < num_threads; i++) {
            total_sum = total_sum.add(threads[i].getSum());
        }

        System.out.printf("The sum of %s-%s is %s.\n", from.toString(), to.toString(), total_sum.toString());
    }
}
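One detail in Sum4 is easy to miss: BigInteger objects are immutable, so add returns a new BigInteger rather than changing the one it is called on. That is why the code writes sum = sum.add(i). A small sketch (hypothetical class name) of the difference:

```java
// Immutable.java (hypothetical example)
import java.math.BigInteger;

class Immutable {
    // The result of add is thrown away, so sum never changes.
    static BigInteger ignoredResult() {
        BigInteger sum = BigInteger.ZERO;
        sum.add(BigInteger.TEN);
        return sum;
    }

    // The new object returned by add is stored back into sum.
    static BigInteger assignedResult() {
        BigInteger sum = BigInteger.ZERO;
        sum = sum.add(BigInteger.TEN);
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(ignoredResult());  // 0
        System.out.println(assignedResult()); // 10
    }
}
```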

 

Speedup

If we run Sum4 to compute the sum of 1 to 100,000,000, we can vary the number of threads and measure the time taken.

This was the result for one test on a particular machine:

Threads   Seconds
1         22.398
2         15.303
4         10.213
8          8.098
16         6.671
32         6.112
64         5.956
128        5.005
256        4.688
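One common way to read these numbers is the speedup, T(1)/T(n), where T(n) is the time with n threads; a perfectly linear speedup would equal n. For example, with 4 threads the speedup is 22.398 / 10.213 ≈ 2.19. The sketch below (hypothetical class name) computes the speedup and efficiency (speedup divided by n) for every row of the table:

```java
// Speedup.java (hypothetical example using the timings from the table above)
class Speedup {
    static final int[] THREADS = {1, 2, 4, 8, 16, 32, 64, 128, 256};
    static final double[] SECONDS = {22.398, 15.303, 10.213, 8.098, 6.671,
                                     6.112, 5.956, 5.005, 4.688};

    // speedup = time with 1 thread / time with THREADS[index] threads
    static double speedup(int index) {
        return SECONDS[0] / SECONDS[index];
    }

    // efficiency = speedup / number of threads; 1.0 means perfectly linear
    static double efficiency(int index) {
        return speedup(index) / THREADS[index];
    }

    public static void main(String[] args) {
        for (int i = 0; i < THREADS.length; i++) {
            System.out.printf("%3d threads: speedup %.2f, efficiency %.2f%n",
                    THREADS[i], speedup(i), efficiency(i));
        }
    }
}
```

With 256 threads the speedup is only about 4.78, so the efficiency is below 0.02.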

Why doesn't the speedup increase linearly with the number of threads?

Why does the speedup level off?


 

Sharing Between Threads

In the examples above, each thread stores its own data. In some cases multiple threads will need to share data.

Sharing data among threads can be tricky, as the following example shows:


// Sharing1.java

// a simple counter class
class Counter {
    private int value;
    
    public Counter() {
        value = 0;
    }

    public void increment() {
        value++;
    }

    public int getValue() {
        return value;
    }
}

// a thread class to increment a counter by a certain amount in parallel
class IncThread extends Thread {
    private Counter counter;
    private int amount;
    
    public IncThread(Counter c, int amt) {
        counter = c;
        amount = amt;
    }

    @Override
    public void run() {
        for (int i = 0; i < amount; i++) {
            counter.increment();
        }
    }
}

public class Sharing1 {
    public static void main(String args[]) {
        // create a counter
        Counter counter = new Counter();

        // make some threads increment the counter
        IncThread [] threads = new IncThread[100];
        for (int i = 0; i < 100; i++) {
            threads[i] = new IncThread(counter, 5);
        }

        // start them all
        for (int i = 0; i < 100; i++) {
            threads[i].start();
        }

        // wait for them to finish
        try {
            for (int i = 0; i < 100; i++) {
                threads[i].join();
            }
        } catch (InterruptedException e) {
            // not expected in this demo; ignore and print whatever was counted
        }

        // get the value
        System.out.println(counter.getValue());
    }
}

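This program should print 500 (100 threads that each increment the counter 5 times), but it is not guaranteed to: increments can be lost, so some runs may print a smaller number. With so little work per thread, the problem may show up only rarely.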


 

Synchronization

The problem with Sharing1 is that all of the threads call counter.increment on the same Counter object. Although value++ looks like a single step, it actually reads the value, adds one to it, and writes the result back.

If one thread reads the value while another thread is between its read and its write, the first thread works with a "stale" value, and when both write back, one of the updates is lost.
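The following sketch (hypothetical class name) replays one such interleaving step by step on a single thread, with two local variables standing in for what each of two threads read. Two increments happen, yet the value only goes up by one:

```java
// LostUpdate.java (hypothetical example)
class LostUpdate {
    // Replays one bad interleaving of two increments on a single thread.
    static int simulate() {
        int value = 0;

        int readByA = value;  // thread A reads 0
        int readByB = value;  // thread B also reads 0 before A writes back
        value = readByA + 1;  // A writes 1
        value = readByB + 1;  // B writes 1 from its stale read, erasing A's update

        return value;         // 1, even though increment ran twice
    }

    public static void main(String[] args) {
        System.out.println(simulate()); // prints 1
    }
}
```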

To fix this issue, we need to synchronize access to the increment method. This can be done in multiple ways, but the simplest is to mark the method as synchronized.

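A thread must acquire the object's lock before running a synchronized method, so only one thread at a time can execute synchronized methods on a given Counter object. This version also splits the increment into separate read and write steps with a one-millisecond sleep in between; without synchronized, that pause would make lost updates very likely, but with it the program prints 500: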


// Sharing2.java

// a simple counter class
class Counter {
    private int value;

    public Counter() {
        value = 0;
    }

    public synchronized void increment() {
        int current_value = value;

        try {
            Thread.sleep(1);
        } catch (InterruptedException e) {
            // not expected here; just continue with the write
        }

        int new_value = current_value + 1;
        value = new_value;
    }

    public int getValue() {
        return value;
    }
}



// a thread class to increment a counter by a certain amount in parallel
class IncThread extends Thread {
    private Counter counter;
    private int amount;

    public IncThread(Counter c, int amt) {
        counter = c;
        amount = amt;
    }

    @Override
    public void run() {
        for (int i = 0; i < amount; i++) {
            counter.increment();
        }
    }
}



public class Sharing2 {
    public static void main(String args[]) {
        // create a counter
        Counter counter = new Counter();

        // make some threads increment the counter
        IncThread [] threads = new IncThread[100];
        for (int i = 0; i < 100; i++) {
            threads[i] = new IncThread(counter, 5);
        }

        // start them all
        for (int i = 0; i < 100; i++) {
            threads[i].start();
        }

        // wait for them to finish
        try {
            for (int i = 0; i < 100; i++) {
                threads[i].join();
            }
        } catch (InterruptedException e) {
            // not expected in this demo; ignore and print whatever was counted
        }

        // get the value
        System.out.println(counter.getValue());
    }
}
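Marking increment as synchronized is only one of the possible approaches. As a sketch of another, assuming the counter only needs increment and read operations, java.util.concurrent.atomic.AtomicInteger provides an incrementAndGet method that performs the read, add, and write as one atomic step without an explicit lock. The class names below are hypothetical:

```java
// Sharing3.java (hypothetical example)
import java.util.concurrent.atomic.AtomicInteger;

// a counter whose increment is atomic without using synchronized
class AtomicCounter {
    private final AtomicInteger value = new AtomicInteger(0);

    public void increment() {
        // read, add one, and write back as one indivisible step
        value.incrementAndGet();
    }

    public int getValue() {
        return value.get();
    }
}

// same idea as IncThread, but using AtomicCounter
class AtomicIncThread extends Thread {
    private final AtomicCounter counter;
    private final int amount;

    public AtomicIncThread(AtomicCounter c, int amt) {
        counter = c;
        amount = amt;
    }

    @Override
    public void run() {
        for (int i = 0; i < amount; i++) {
            counter.increment();
        }
    }
}

class Sharing3 {
    // starts numThreads threads that each increment the counter amount times
    static int run(int numThreads, int amount) {
        AtomicCounter counter = new AtomicCounter();
        AtomicIncThread[] threads = new AtomicIncThread[numThreads];
        for (int i = 0; i < numThreads; i++) {
            threads[i] = new AtomicIncThread(counter, amount);
            threads[i].start();
        }

        // wait for them to finish
        try {
            for (AtomicIncThread t : threads) {
                t.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting", e);
        }
        return counter.getValue();
    }

    public static void main(String[] args) {
        System.out.println(run(100, 5)); // prints 500
    }
}
```

An atomic variable fits well when the shared state is a single number; when several fields must change together, a synchronized method like the one in Sharing2 is usually simpler to get right.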

Copyright © 2024 Ian Finlayson | Licensed under a Creative Commons BY-NC-SA 4.0 License.