de.htwdd.rosenkoenig.neuro.net.training
Class Trainer

java.lang.Object
  extended by de.htwdd.rosenkoenig.neuro.net.training.Trainer

public class Trainer
extends java.lang.Object

The Trainer class trains a FeedForwardNet. It can use different teaching algorithms, e.g. backpropagation.

To train a net, you have to provide the trainer with a set of patterns and a teaching algorithm, set the training parameters (e.g. learning rate, batch size, maximum error) and start the training process. The trainer runs until the training error is smaller than the specified maximum error or until the maximum number of training cycles is reached.
After this training phase, the trainer starts a validation phase (if the validation coefficient is > 0), which runs for at most the same number of cycles as the training phase. To obtain two separate sets of patterns, the trainer splits the pattern set into a training set and a validation set.
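The stopping rule of the training phase can be sketched as follows. This is a minimal illustration, not the Trainer's actual code: the `TrainingCycle` interface and the `runPhase` helper are hypothetical stand-ins for one run of the teaching algorithm that reports the resulting training error.

```java
/** Hypothetical stand-in for one run of the teaching algorithm (not part of the Trainer API). */
interface TrainingCycle {
    /** Performs one training cycle and returns the training error measured afterwards. */
    double run(int cycle);
}

final class TrainingLoopSketch {
    private TrainingLoopSketch() {}

    /**
     * Runs cycles until the error is smaller than maxError or maxCycles cycles
     * have been performed, whichever comes first. Returns the number of cycles run.
     */
    static int runPhase(TrainingCycle step, int maxCycles, double maxError) {
        int cycle = 0;
        double error = Double.POSITIVE_INFINITY; // no error measured yet
        while (cycle < maxCycles && error >= maxError) {
            error = step.run(cycle);
            cycle++;
        }
        return cycle;
    }
}
```

With an error that halves every cycle (1.0, 0.5, 0.25, 0.125, ...) and maxError = 0.2, `runPhase` stops after four cycles; with an error that never drops, it stops after maxCycles cycles.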

The following list shows the default values of the different training parameters:


Field Summary
private  TeachingAlgorithm algorithm
          This algorithm is used to train the net.
private  int batchsize
          The number of patterns that will be processed before the weights of the net are corrected.
private  int cycle
          The number of the current training cycle.
private  double error
          The last calculated training error of the net.
private  double flatspot
          This property is useful to raise the derivative of functions like the sigmoid, whose derivative is relatively flat.
private  double learningRate
           
private  org.apache.log4j.Logger log
           
private  int maxCycles
          The training process is stopped after maxCycles cycles, even if the training error might be higher than desired.
private  double maxError
          Training stops, if the training error is lower than this value.
private  double momentum
           
private  FeedForwardNet net
          The net that will be trained.
private  java.util.List<Pattern> patterns
          The patterns set provided by the user.
private  boolean pruneAfterTraining
          If set to true the net will be pruned after training.
private  double pruningLimit
          If pruning is activated, all weights smaller than pruningLimit will be set to 0 after training.
private  java.util.List<Pattern> trainingPatterns
          This set of patterns is extracted from patterns.
private  double validationCoeff
          If the value of this property is > 0, the provided pattern set will be split into a training set and a validation set.
private  double validationError
          The last calculated validation error of the net.
private  java.util.List<Pattern> validationPatterns
          This set of patterns is extracted from patterns.
 
Constructor Summary
Trainer(FeedForwardNet net, TeachingAlgorithm algorithm)
          Creates a new Trainer with default parameters.
 
Method Summary
 void addPatterns(int index, Pattern pattern)
          Inserts the specified pattern at the specified position in the pattern list (optional operation).
 boolean addPatterns(Pattern pattern)
          Appends the specified pattern to the end of the pattern list (optional operation).
protected  double calculateError(java.util.List<Pattern> patterns)
          Calculates the network's error on the provided pattern set.
 void clearPatterns()
          Removes all patterns from the pattern list (optional operation).
 boolean containsAllPatterns(java.util.Collection patterns)
          Returns true if the pattern list contains all of the elements of the specified collection.
 boolean containsPatterns(Pattern pattern)
          Returns true if the pattern list contains the specified pattern.
 TeachingAlgorithm getAlgorithm()
          Getter of the property algorithm
 int getBatchsize()
          Getter of the property batchsize
 int getCycle()
          Getter of the property cycle
 double getError()
          Getter of the property error
 double getFlatspot()
          Getter of the property flatspot
 double getLearningRate()
          Getter of the property learningRate
 int getMaxCycles()
          Getter of the property maxCycles
 double getMaxError()
          Getter of the property maxError
 double getMomentum()
          Getter of the property momentum
 FeedForwardNet getNet()
          Getter of the property net
 java.util.List<Pattern> getPatterns()
          Getter of the property patterns
 Pattern getPatterns(int i)
          Returns the pattern at the specified position in the pattern list.
 boolean getPruneAfterTraining()
          Getter of the property pruneAfterTraining
 double getPruningLimit()
          Getter of the property pruningLimit
 double getValidationCoeff()
          Getter of the property validationCoeff
 double getValidationError()
          Getter of the property validationError
 boolean isPatternsEmpty()
          Returns true if the pattern list contains no patterns.
 java.util.Iterator<Pattern> patternsIterator()
          Returns an iterator over the patterns in the pattern list in proper sequence.
 int patternsSize()
          Returns the number of patterns in the pattern list.
 Pattern[] patternsToArray()
          Returns an array containing all patterns of the pattern list in proper sequence.
 Pattern[] patternsToArray(Pattern[] patterns)
          Returns an array containing all patterns of the pattern list in proper sequence; the runtime type of the returned array is that of the specified array.
private  void pruneNet()
          The weight of all connections with a weight lower than pruning limit are set to zero.
 java.lang.Object removePatterns(int index)
          Removes the pattern at the specified position in the pattern list (optional operation).
 boolean removePatterns(Pattern pattern)
          Removes the first occurrence of the specified pattern from the pattern list (optional operation).
 void setAlgorithm(TeachingAlgorithm algorithm)
          Setter of the property algorithm
 void setBatchsize(int batchsize)
          Setter of the property batchsize
 void setFlatspot(double flatspot)
          Setter of the property flatspot
 void setLearningRate(double learningRate)
          Setter of the property learningRate
 void setMaxCycles(int maxCycles)
          Setter of the property maxCycles
 void setMaxError(double maxError)
          Setter of the property maxError
 void setMomentum(double momentum)
          Setter of the property momentum
 void setNet(FeedForwardNet net)
          Setter of the property net
 void setPatterns(java.util.List<Pattern> patterns)
          Setter of the property patterns
 void setPruneAfterTraining(boolean pruneAfterTraining)
          Setter of the property pruneAfterTraining
 void setPruningLimit(double pruningLimit)
          Setter of the property pruningLimit
 void setValidationCoeff(double validationCoeff)
          Setter of the property validationCoeff
 void train()
          train() controls the training process, which consists of the following steps: extract the training and validation sets from the provided pattern set; run the teaching algorithm until the error is small enough (the pattern sets are shuffled before each run); run retraining until the validation error is small enough (optional); prune the net (optional).
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

log

private org.apache.log4j.Logger log

net

private FeedForwardNet net
The net that will be trained.


validationPatterns

private java.util.List<Pattern> validationPatterns
This set of patterns is extracted from patterns. Size of the validation set: patterns.size() * validationCoeff


trainingPatterns

private java.util.List<Pattern> trainingPatterns
This set of patterns is extracted from patterns. Size of the training set: patterns.size() * (1 - validationCoeff)


patterns

private java.util.List<Pattern> patterns
The patterns set provided by the user.


algorithm

private TeachingAlgorithm algorithm
This algorithm is used to train the net.


learningRate

private double learningRate

maxError

private double maxError
Training stops, if the training error is lower than this value.


maxCycles

private int maxCycles
The training process is stopped after maxCycles cycles, even if the training error might be higher than desired.


cycle

private int cycle
The number of the current training cycle.


momentum

private double momentum

validationCoeff

private double validationCoeff
If the value of this property is > 0, the provided pattern set will be split into a training set and a validation set.
Size of the validation set: validationCoeff * patterns.size()
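The split controlled by validationCoeff (and by the size formulas given for validationPatterns and trainingPatterns) could look like the following sketch. Rounding the validation size down and shuffling a copy before splitting are assumptions; this documentation does not say how the Trainer rounds or orders the patterns. The generic type T stands in for Pattern.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

final class PatternSplitSketch {
    private PatternSplitSketch() {}

    /** Holds the two disjoint subsets produced by split(). */
    static final class Split<T> {
        final List<T> training;
        final List<T> validation;

        Split(List<T> training, List<T> validation) {
            this.training = training;
            this.validation = validation;
        }
    }

    /**
     * Splits patterns into a validation set of floor(size * validationCoeff) elements
     * and a training set holding the rest. The input list itself is not modified.
     */
    static <T> Split<T> split(List<T> patterns, double validationCoeff, Random random) {
        if (validationCoeff < 0.0 || validationCoeff >= 1.0) {
            throw new IllegalArgumentException("validationCoeff must be in [0, 1)");
        }
        List<T> shuffled = new ArrayList<>(patterns); // work on a copy
        Collections.shuffle(shuffled, random);
        int validationSize = (int) (patterns.size() * validationCoeff);
        List<T> validation = new ArrayList<>(shuffled.subList(0, validationSize));
        List<T> training = new ArrayList<>(shuffled.subList(validationSize, shuffled.size()));
        return new Split<>(training, validation);
    }
}
```

For 10 patterns and validationCoeff = 0.3, this yields 3 validation patterns and 7 training patterns, and every pattern ends up in exactly one of the two sets.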


flatspot

private double flatspot
This property is useful to raise the derivative of functions like the sigmoid, whose derivative is relatively flat.
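Flat spot elimination is one common technique that fits this description: a small constant is added to the derivative of the activation function, so that saturated sigmoid units still pass on an error signal. The sketch below assumes that flatspot is simply added to the sigmoid derivative f'(x) = f(x)(1 - f(x)); whether the Trainer applies it exactly this way is not stated in this documentation.

```java
final class FlatSpotSketch {
    private FlatSpotSketch() {}

    /** Logistic sigmoid 1 / (1 + e^-x). */
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Derivative of the sigmoid, f'(x) = f(x) * (1 - f(x)), raised by the constant
     * flatspot term so that it stays away from zero for large |x|.
     */
    static double derivative(double x, double flatspot) {
        double y = sigmoid(x);
        return y * (1.0 - y) + flatspot;
    }
}
```

At x = 0 the plain derivative is 0.25; at x = 20 it is almost zero, while a flatspot of 0.1 keeps it at about 0.1.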


batchsize

private int batchsize
The number of patterns that will be processed before the weights of the net are corrected.
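The effect of batchsize can be illustrated with a single weight. In this sketch, which is not the Trainer's implementation, gradients are evaluated at the current weight and summed over batchsize patterns before one correction is applied; applying a final partial batch at the end of the pass is an assumption. The gradient function is a hypothetical stand-in for what the teaching algorithm computes per pattern.

```java
import java.util.List;
import java.util.function.ToDoubleBiFunction;

final class BatchUpdateSketch {
    private BatchUpdateSketch() {}

    /**
     * Performs one pass over the patterns for a single weight. Gradients are summed
     * over batchsize patterns; the weight is corrected once per full batch and once
     * more for a trailing partial batch. Returns the updated weight.
     */
    static <T> double onePass(List<T> patterns, int batchsize, double weight,
                              double learningRate, ToDoubleBiFunction<T, Double> gradient) {
        if (batchsize < 1) {
            throw new IllegalArgumentException("batchsize must be >= 1");
        }
        double accumulated = 0.0;
        int inBatch = 0;
        for (T pattern : patterns) {
            // Gradient is evaluated at the weight as it was before this batch's correction.
            accumulated += gradient.applyAsDouble(pattern, weight);
            inBatch++;
            if (inBatch == batchsize) {
                weight -= learningRate * accumulated;
                accumulated = 0.0;
                inBatch = 0;
            }
        }
        if (inBatch > 0) {
            weight -= learningRate * accumulated; // trailing partial batch
        }
        return weight;
    }
}
```

With four targets of 1.0, a gradient of w - t, a learning rate of 0.25 and a starting weight of 0, a batchsize of 4 ends at 1.0, whereas a batchsize of 1 ends at 0.68359375, because each correction changes the weight at which the next gradient is evaluated.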


pruneAfterTraining

private boolean pruneAfterTraining
If set to true, the net will be pruned after training. In this case, pruning means setting every weight that is smaller than pruningLimit to 0.


pruningLimit

private double pruningLimit
If pruning is activated, all weights smaller than pruningLimit will be set to 0 after training.
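A minimal sketch of the pruning step, assuming the weights are stored in a plain double[][] matrix (a hypothetical representation; the FeedForwardNet internals are not documented here). It also assumes that "smaller than pruningLimit" refers to the magnitude of a weight, so that large negative weights are kept; the source does not state this explicitly.

```java
final class PruningSketch {
    private PruningSketch() {}

    /**
     * Sets every non-zero weight whose magnitude is below pruningLimit to 0 and
     * returns how many weights were pruned. The matrix is modified in place.
     */
    static int prune(double[][] weights, double pruningLimit) {
        int pruned = 0;
        for (double[] row : weights) {
            for (int j = 0; j < row.length; j++) {
                if (row[j] != 0.0 && Math.abs(row[j]) < pruningLimit) {
                    row[j] = 0.0;
                    pruned++;
                }
            }
        }
        return pruned;
    }
}
```

With a limit of 0.05, the weights -0.01 and 0.02 are set to 0, while 0.5 and -0.8 are kept.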


error

private double error
The last calculated training error of the net.


validationError

private double validationError
The last calculated validation error of the net.

Constructor Detail

Trainer

public Trainer(FeedForwardNet net,
               TeachingAlgorithm algorithm)
Creates a new Trainer with default parameters.

Parameters:
net - the FeedForwardNet to train
algorithm - the algorithm to use for training
Method Detail

getNet

public FeedForwardNet getNet()
Getter of the property net

Returns:
Returns the net.

setNet

public void setNet(FeedForwardNet net)
Setter of the property net

Parameters:
net - The net to set.

getPatterns

public java.util.List<Pattern> getPatterns()
Getter of the property patterns

Returns:
Returns the patterns.

getPatterns

public Pattern getPatterns(int i)
Returns the pattern at the specified position in the pattern list.

Parameters:
i - index of the pattern to return.
Returns:
the pattern at the specified position in the pattern list.
See Also:
List.get(int)

patternsIterator

public java.util.Iterator<Pattern> patternsIterator()
Returns an iterator over the patterns in the pattern list in proper sequence.

Returns:
an iterator over the patterns in the pattern list in proper sequence.
See Also:
List.iterator()

isPatternsEmpty

public boolean isPatternsEmpty()
Returns true if the pattern list contains no patterns.

Returns:
true if the pattern list contains no patterns.
See Also:
List.isEmpty()

containsPatterns

public boolean containsPatterns(Pattern pattern)
Returns true if the pattern list contains the specified pattern.

Parameters:
pattern - pattern whose presence in the pattern list is to be tested.
Returns:
true if the pattern list contains the specified pattern.
See Also:
List.contains(Object)

containsAllPatterns

public boolean containsAllPatterns(java.util.Collection patterns)
Returns true if the pattern list contains all of the elements of the specified collection.

Parameters:
patterns - collection to be checked for containment in the pattern list.
Returns:
true if the pattern list contains all of the elements of the specified collection.
See Also:
List.containsAll(Collection)

patternsSize

public int patternsSize()
Returns the number of patterns in the pattern list.

Returns:
the number of patterns in the pattern list.
See Also:
List.size()

patternsToArray

public Pattern[] patternsToArray()
Returns an array containing all patterns of the pattern list in proper sequence.

Returns:
an array containing all patterns of the pattern list in proper sequence.
See Also:
List.toArray()

patternsToArray

public Pattern[] patternsToArray(Pattern[] patterns)
Returns an array containing all patterns of the pattern list in proper sequence; the runtime type of the returned array is that of the specified array.

Parameters:
patterns - the array into which the patterns of the pattern list are to be stored.
Returns:
an array containing all patterns of the pattern list in proper sequence.
See Also:
List.toArray(Object[])

addPatterns

public void addPatterns(int index,
                        Pattern pattern)
Inserts the specified pattern at the specified position in the pattern list (optional operation).

Parameters:
index - index at which the specified pattern is to be inserted.
pattern - pattern to be inserted.
See Also:
List.add(int,Object)

addPatterns

public boolean addPatterns(Pattern pattern)
Appends the specified pattern to the end of the pattern list (optional operation).

Parameters:
pattern - pattern to be appended to the pattern list.
Returns:
true (as per the general contract of the Collection.add method).
See Also:
List.add(Object)

removePatterns

public java.lang.Object removePatterns(int index)
Removes the pattern at the specified position in the pattern list (optional operation).

Parameters:
index - the index of the pattern to be removed.
Returns:
the pattern previously at the specified position.
See Also:
List.remove(int)

removePatterns

public boolean removePatterns(Pattern pattern)
Removes the first occurrence of the specified pattern from the pattern list (optional operation).

Parameters:
pattern - pattern to be removed from the pattern list, if present.
Returns:
true if the pattern list contained the specified pattern.
See Also:
List.remove(Object)

clearPatterns

public void clearPatterns()
Removes all patterns from the pattern list (optional operation).

See Also:
List.clear()

setPatterns

public void setPatterns(java.util.List<Pattern> patterns)
Setter of the property patterns

Parameters:
patterns - the patterns to set.

getAlgorithm

public TeachingAlgorithm getAlgorithm()
Getter of the property algorithm

Returns:
Returns the algorithm.

setAlgorithm

public void setAlgorithm(TeachingAlgorithm algorithm)
Setter of the property algorithm

Parameters:
algorithm - The algorithm to set.

train

public void train()

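train() controls the training process, which consists of the following steps:
1. Extract the training and validation sets from the provided pattern set.
2. Run the teaching algorithm until the error is small enough (the pattern sets are shuffled before each run).
3. Run retraining until the validation error is small enough (optional).
4. Prune the net (optional).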
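train() is described as shuffling the pattern sets before each run of the teaching algorithm. A minimal sketch of that behaviour follows; the seeded java.util.Random (for reproducible runs) and the teachOneCycle callback, which stands in for the TeachingAlgorithm, are assumptions rather than the Trainer's actual API.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.Consumer;

final class ShuffledCycleSketch {
    private ShuffledCycleSketch() {}

    /**
     * Runs the given number of cycles. Before each cycle the training patterns are
     * shuffled, so the teaching step sees them in a new order every time.
     * The caller's list is never reordered; a private copy is shuffled instead.
     */
    static <T> void runCycles(List<T> trainingPatterns, int cycles, Random random,
                              Consumer<List<T>> teachOneCycle) {
        List<T> order = new ArrayList<>(trainingPatterns);
        for (int c = 0; c < cycles; c++) {
            Collections.shuffle(order, random);
            teachOneCycle.accept(Collections.unmodifiableList(order));
        }
    }
}
```

Shuffling a copy keeps the user-provided pattern list unchanged, while each cycle still visits every training pattern exactly once.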


pruneNet

private void pruneNet()
Sets the weight of every connection whose weight is lower than pruningLimit to zero.


getLearningRate

public double getLearningRate()
Getter of the property learningRate

Returns:
Returns the learningRate.

setLearningRate

public void setLearningRate(double learningRate)
Setter of the property learningRate

Parameters:
learningRate - The learningRate to set.

getMaxError

public double getMaxError()
Getter of the property maxError

Returns:
Returns the maxError.

setMaxError

public void setMaxError(double maxError)
Setter of the property maxError

Parameters:
maxError - The maxError to set.

getMaxCycles

public int getMaxCycles()
Getter of the property maxCycles

Returns:
Returns the maxCycles.

setMaxCycles

public void setMaxCycles(int maxCycles)
Setter of the property maxCycles

Parameters:
maxCycles - The maxCycles to set.

calculateError

protected double calculateError(java.util.List<Pattern> patterns)
Calculates the network's error on the provided pattern set.

Parameters:
patterns - pattern set
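The error measure used by calculateError is not specified in this documentation. The sketch below assumes the mean squared error over all output units of all patterns; the Sample class (standing in for Pattern) and the Function standing in for the net are hypothetical.

```java
import java.util.List;
import java.util.function.Function;

final class ErrorSketch {
    private ErrorSketch() {}

    /** Hypothetical stand-in for Pattern: an input vector and its target output. */
    static final class Sample {
        final double[] input;
        final double[] target;

        Sample(double[] input, double[] target) {
            this.input = input;
            this.target = target;
        }
    }

    /** Mean squared error of the net over all samples and all output units. */
    static double meanSquaredError(Function<double[], double[]> net, List<Sample> samples) {
        double sum = 0.0;
        int count = 0;
        for (Sample s : samples) {
            double[] output = net.apply(s.input);
            for (int k = 0; k < s.target.length; k++) {
                double diff = s.target[k] - output[k];
                sum += diff * diff;
                count++;
            }
        }
        return count == 0 ? 0.0 : sum / count; // empty set: report zero error
    }
}
```

For an identity net, a sample with input {0, 0} and target {1, 3} contributes squared errors 1 and 9; together with a perfectly matched sample, the mean over four outputs is 2.5.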

getCycle

public int getCycle()
Getter of the property cycle

Returns:
Returns the number of the current training cycle.

getMomentum

public double getMomentum()
Getter of the property momentum

Returns:
Returns the momentum.

setMomentum

public void setMomentum(double momentum)
Setter of the property momentum

Parameters:
momentum - The momentum to set.

getValidationCoeff

public double getValidationCoeff()
Getter of the property validationCoeff

Returns:
Returns the validationCoeff.

setValidationCoeff

public void setValidationCoeff(double validationCoeff)
Setter of the property validationCoeff

Parameters:
validationCoeff - The validationCoeff to set.

getFlatspot

public double getFlatspot()
Getter of the property flatspot

Returns:
Returns the flatspot.

setFlatspot

public void setFlatspot(double flatspot)
Setter of the property flatspot

Parameters:
flatspot - The flatspot to set.

getBatchsize

public int getBatchsize()
Getter of the property batchsize

Returns:
Returns the batchsize.

setBatchsize

public void setBatchsize(int batchsize)
Setter of the property batchsize

Parameters:
batchsize - The batchsize to set.

getPruneAfterTraining

public boolean getPruneAfterTraining()
Getter of the property pruneAfterTraining

Returns:
Returns the pruneAfterTraining.

setPruneAfterTraining

public void setPruneAfterTraining(boolean pruneAfterTraining)
Setter of the property pruneAfterTraining

Parameters:
pruneAfterTraining - The pruneAfterTraining to set.

getPruningLimit

public double getPruningLimit()
Getter of the property pruningLimit

Returns:
Returns the pruningLimit.

setPruningLimit

public void setPruningLimit(double pruningLimit)
Setter of the property pruningLimit

Parameters:
pruningLimit - The pruningLimit to set.

getError

public double getError()
Getter of the property error

Returns:
Returns the error.

getValidationError

public double getValidationError()
Getter of the property validationError

Returns:
Returns the validationError.