java.lang.Object
  extended by de.htwdd.rosenkoenig.neuro.net.training.Trainer
public class Trainer
The Trainer class trains a FeedForwardNet. It can use different teaching algorithms, e.g. backpropagation.
To train a net, provide the trainer with a set of patterns and a teaching algorithm,
set the training parameters (e.g. learning rate, batch size, maximum error) and start the
training process. The trainer runs until the error is smaller than the specified maximum
error or until the maximum number of training cycles is reached.
After this training phase, the trainer starts a validation phase (if the validation coefficient is
> 0), which runs for at most the same number of cycles as the training phase. To obtain two
sets of patterns, the trainer splits the pattern set into a training set and a
validation set.
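The split described above can be sketched as follows. This is a minimal illustration, not Trainer's code: it assumes the patterns are shuffled before splitting and that the validation set receives (int) (patterns.size() * validationCoeff) patterns, the size given in the field details, with the rest going to the training set. The class PatternSplit is hypothetical and uses a type parameter in place of Pattern.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper that splits a pattern list into a training set and a validation set. */
final class PatternSplit<T> {
    final List<T> training;
    final List<T> validation;

    private PatternSplit(List<T> training, List<T> validation) {
        this.training = training;
        this.validation = validation;
    }

    /**
     * Assumption: the validation set receives (int) (patterns.size() * validationCoeff) randomly
     * chosen patterns and the training set receives the rest. With validationCoeff == 0 every
     * pattern ends up in the training set.
     */
    static <E> PatternSplit<E> split(List<E> patterns, double validationCoeff, Random rnd) {
        if (validationCoeff < 0.0 || validationCoeff >= 1.0) {
            throw new IllegalArgumentException("validationCoeff must lie in [0, 1)");
        }
        List<E> shuffled = new ArrayList<>(patterns); // copy, so the caller's list is not reordered
        Collections.shuffle(shuffled, rnd);
        int validationSize = (int) (shuffled.size() * validationCoeff);
        List<E> validation = new ArrayList<>(shuffled.subList(0, validationSize));
        List<E> training = new ArrayList<>(shuffled.subList(validationSize, shuffled.size()));
        return new PatternSplit<>(training, validation);
    }
}
```

Copying before shuffling keeps the user-supplied patterns list in its original order, so repeated calls start from the same input.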
The following list shows the default values of the different training parameters:
| Field Summary | |
|---|---|
| private TeachingAlgorithm | algorithm: The teaching algorithm used to train the net. |
| private int | batchsize: The number of patterns processed before the weights of the net are corrected. |
| private int | cycle: The number of the current training cycle. |
| private double | error: The last calculated training error of the net. |
| private double | flatspot: Increases the derivative of functions such as the sigmoid, whose derivative is relatively flat. |
| private double | learningRate |
| private org.apache.log4j.Logger | log |
| private int | maxCycles: The training process stops after maxCycles cycles, even if the training error is still higher than desired. |
| private double | maxError: Training stops once the training error is lower than this value. |
| private double | momentum |
| private FeedForwardNet | net: The net that will be trained. |
| private java.util.List<Pattern> | patterns: The pattern set provided by the user. |
| private boolean | pruneAfterTraining: If set to true, the net is pruned after training. |
| private double | pruningLimit: If pruning is activated, all weights smaller than pruningLimit are set to 0 after training. |
| private java.util.List<Pattern> | trainingPatterns: This set of patterns is extracted from patterns. |
| private double | validationCoeff: If this value is > 0, the provided pattern set is split into a training set and a validation set. |
| private double | validationError: The last calculated validation error of the net. |
| private java.util.List<Pattern> | validationPatterns: This set of patterns is extracted from patterns. |
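The flatspot field matches the flat-spot elimination technique known from backpropagation, in which a small constant is added to the derivative of the sigmoid so that saturated units, whose derivative is close to zero, still receive weight corrections. A minimal sketch, assuming Trainer uses flatspot in exactly this additive way; the class FlatSpot and its methods are hypothetical.

```java
/** Hypothetical sketch: logistic sigmoid and its derivative with an additive flat-spot term. */
final class FlatSpot {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Derivative of the sigmoid written via its output s = sigmoid(x), i.e. s * (1 - s),
     * plus the flat-spot constant. Assumption: Trainer adds flatspot to the derivative like this.
     */
    static double derivative(double x, double flatspot) {
        double s = sigmoid(x);
        return s * (1.0 - s) + flatspot;
    }
}
```

Without the term the derivative at x = 20 is only about 2e-9, so the corresponding weight barely changes; with flatspot = 0.1 it never drops below 0.1.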
| Constructor Summary | |
|---|---|
| Trainer(FeedForwardNet net, TeachingAlgorithm algorithm) | Creates a new Trainer with default parameters. |
| Method Summary | |
|---|---|
| void | addPatterns(int index, Pattern pattern): Inserts the specified pattern at the specified position in the pattern list (optional operation). |
| boolean | addPatterns(Pattern pattern): Appends the specified pattern to the end of the pattern list (optional operation). |
| protected double | calculateError(java.util.List<Pattern> patterns): Calculates the network's error on the provided pattern set. |
| void | clearPatterns(): Removes all patterns from the pattern list (optional operation). |
| boolean | containsAllPatterns(java.util.Collection patterns): Returns true if the pattern list contains all elements of the specified collection. |
| boolean | containsPatterns(Pattern pattern): Returns true if the pattern list contains the specified pattern. |
| TeachingAlgorithm | getAlgorithm(): Getter of the property algorithm. |
| int | getBatchsize(): Getter of the property batchsize. |
| int | getCycle(): Getter of the property cycle. |
| double | getError(): Getter of the property error. |
| double | getFlatspot(): Getter of the property flatspot. |
| double | getLearningRate(): Getter of the property learningRate. |
| int | getMaxCycles(): Getter of the property maxCycles. |
| double | getMaxError(): Getter of the property maxError. |
| double | getMomentum(): Getter of the property momentum. |
| FeedForwardNet | getNet(): Getter of the property net. |
| java.util.List<Pattern> | getPatterns(): Getter of the property patterns. |
| Pattern | getPatterns(int i): Returns the pattern at the specified position in the pattern list. |
| boolean | getPruneAfterTraining(): Getter of the property pruneAfterTraining. |
| double | getPruningLimit(): Getter of the property pruningLimit. |
| double | getValidationCoeff(): Getter of the property validationCoeff. |
| double | getValidationError(): Getter of the property validationError. |
| boolean | isPatternsEmpty(): Returns true if the pattern list contains no elements. |
| java.util.Iterator<Pattern> | patternsIterator(): Returns an iterator over the patterns in proper sequence. |
| int | patternsSize(): Returns the number of patterns in the pattern list. |
| Pattern[] | patternsToArray(): Returns an array containing all patterns in proper sequence. |
| Pattern[] | patternsToArray(Pattern[] patterns): Returns an array containing all patterns in proper sequence; the runtime type of the returned array is that of the specified array. |
| private void | pruneNet(): Sets the weight of every connection whose weight is lower than pruningLimit to zero. |
| java.lang.Object | removePatterns(int index): Removes the pattern at the specified position in the pattern list (optional operation). |
| boolean | removePatterns(Pattern pattern): Removes the first occurrence of the specified pattern from the pattern list (optional operation). |
| void | setAlgorithm(TeachingAlgorithm algorithm): Setter of the property algorithm. |
| void | setBatchsize(int batchsize): Setter of the property batchsize. |
| void | setFlatspot(double flatspot): Setter of the property flatspot. |
| void | setLearningRate(double learningRate): Setter of the property learningRate. |
| void | setMaxCycles(int maxCycles): Setter of the property maxCycles. |
| void | setMaxError(double maxError): Setter of the property maxError. |
| void | setMomentum(double momentum): Setter of the property momentum. |
| void | setNet(FeedForwardNet net): Setter of the property net. |
| void | setPatterns(java.util.List<Pattern> patterns): Setter of the property patterns. |
| void | setPruneAfterTraining(boolean pruneAfterTraining): Setter of the property pruneAfterTraining. |
| void | setPruningLimit(double pruningLimit): Setter of the property pruningLimit. |
| void | setValidationCoeff(double validationCoeff): Setter of the property validationCoeff. |
| void | train(): Controls the training process, which consists of the following steps: (1) extract the training and validation sets from the provided pattern set; (2) run the teaching algorithm until the error is small enough (the pattern sets are shuffled before each run); (3) run retraining until the validation error is small enough (optional); (4) prune the net (optional). |
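Most pattern methods in this table carry the wording of java.util.List, and the method details refer to List.get(int), List.add(Object) and similar, which suggests they simply forward to the patterns list. A minimal sketch of that delegation, using a hypothetical PatternStore class with a type parameter in place of Pattern (patternsToArray is omitted for brevity):

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

/** Hypothetical stand-in for Trainer's pattern handling: every method forwards to a java.util.List. */
final class PatternStore<P> {
    private List<P> patterns = new ArrayList<>();

    void setPatterns(List<P> patterns) { this.patterns = patterns; }
    List<P> getPatterns() { return patterns; }

    P getPatterns(int i) { return patterns.get(i); }                           // List.get(int)
    Iterator<P> patternsIterator() { return patterns.iterator(); }             // List.iterator()
    boolean isPatternsEmpty() { return patterns.isEmpty(); }                   // List.isEmpty()
    boolean containsPatterns(P pattern) { return patterns.contains(pattern); } // List.contains(Object)
    boolean containsAllPatterns(Collection<?> c) { return patterns.containsAll(c); }
    int patternsSize() { return patterns.size(); }                             // List.size()

    boolean addPatterns(P pattern) { return patterns.add(pattern); }           // List.add(Object)
    void addPatterns(int index, P pattern) { patterns.add(index, pattern); }   // List.add(int, Object)
    P removePatterns(int index) { return patterns.remove(index); }             // List.remove(int)
    boolean removePatterns(P pattern) { return patterns.remove(pattern); }     // List.remove(Object)
    void clearPatterns() { patterns.clear(); }                                 // List.clear()
}
```

Unlike the documented removePatterns(int), which is declared to return java.lang.Object, the sketch returns the element type. Because setPatterns stores the caller's list directly, an unmodifiable list would make the add and remove methods throw UnsupportedOperationException, which is what "(optional operation)" refers to.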
| Methods inherited from class java.lang.Object |
|---|
| clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
| Field Detail |
|---|
private org.apache.log4j.Logger log
private FeedForwardNet net
private java.util.List<Pattern> validationPatterns
Contains patterns.size() * validationCoeff patterns.
private java.util.List<Pattern> trainingPatterns
Contains patterns.size() * (1 - validationCoeff) patterns.
private java.util.List<Pattern> patterns
private TeachingAlgorithm algorithm
private double learningRate
private double maxError
private int maxCycles
private int cycle
private double momentum
private double validationCoeff
The validation set contains validationCoeff * patterns.size() patterns.
private double flatspot
private int batchsize
private boolean pruneAfterTraining
private double pruningLimit
private double error
private double validationError
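The batchsize, learningRate and momentum fields together determine how weight corrections are applied. A minimal single-weight sketch, assuming the common backpropagation-with-momentum rule delta(t) = -learningRate * g + momentum * delta(t-1), where g is the error gradient summed over the last batchsize patterns. Summing rather than averaging, and the class BatchUpdater itself, are assumptions; this page documents neither.

```java
/** Hypothetical single-weight updater: corrects the weight once every batchsize patterns. */
final class BatchUpdater {
    private final double learningRate;
    private final double momentum;
    private final int batchsize;

    double weight;
    private double accumulatedGradient = 0.0; // sum of dE/dw since the last correction
    private int patternsInBatch = 0;
    private double previousDelta = 0.0;       // last weight change, reused by the momentum term

    BatchUpdater(double weight, double learningRate, double momentum, int batchsize) {
        if (batchsize < 1) throw new IllegalArgumentException("batchsize must be >= 1");
        this.weight = weight;
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.batchsize = batchsize;
    }

    /** Feeds the gradient of one pattern; applies a correction when the batch is full. */
    void addPatternGradient(double gradient) {
        accumulatedGradient += gradient;
        patternsInBatch++;
        if (patternsInBatch == batchsize) {
            flush();
        }
    }

    /** Applies the pending correction, e.g. for an incomplete last batch at the end of a cycle. */
    void flush() {
        if (patternsInBatch == 0) return;
        double delta = -learningRate * accumulatedGradient + momentum * previousDelta;
        weight += delta;
        previousDelta = delta;
        accumulatedGradient = 0.0;
        patternsInBatch = 0;
    }
}
```

With batchsize = 1 this degenerates to online (per-pattern) learning; larger values trade fewer corrections per cycle for a smoother gradient estimate.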
| Constructor Detail |
|---|
public Trainer(FeedForwardNet net,
TeachingAlgorithm algorithm)
Parameters:
net - the FeedForwardNet to train
algorithm - the algorithm to use for training
| Method Detail |
|---|
public FeedForwardNet getNet()
public void setNet(FeedForwardNet net)
Parameters:
net - The net to set.
public java.util.List<Pattern> getPatterns()
public Pattern getPatterns(int i)
Parameters:
i - index of the pattern to return.
See Also:
List.get(int)
public java.util.Iterator<Pattern> patternsIterator()
See Also:
List.iterator()
public boolean isPatternsEmpty()
See Also:
List.isEmpty()
public boolean containsPatterns(Pattern pattern)
Parameters:
pattern - pattern whose presence in the pattern list is to be tested.
See Also:
List.contains(Object)
public boolean containsAllPatterns(java.util.Collection patterns)
Parameters:
patterns - collection to be checked for containment in the pattern list.
See Also:
List.containsAll(Collection)
public int patternsSize()
See Also:
List.size()
public Pattern[] patternsToArray()
See Also:
List.toArray()
public Pattern[] patternsToArray(Pattern[] patterns)
Parameters:
patterns - the array into which the patterns are to be stored.
See Also:
List.toArray(Object[])
public void addPatterns(int index,
Pattern pattern)
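Parameters:
index - index at which the specified pattern is to be inserted.
pattern - pattern to be inserted.
See Also:
List.add(int,Object)
public boolean addPatterns(Pattern pattern)
Parameters:
pattern - pattern to be appended to the pattern list.
See Also:
List.add(Object)
public java.lang.Object removePatterns(int index)
Parameters:
index - the index of the pattern to be removed.
See Also:
List.remove(int)
public boolean removePatterns(Pattern pattern)
Parameters:
pattern - pattern to be removed from the pattern list, if present.
See Also:
List.remove(Object)
public void clearPatterns()
See Also:
List.clear()
public void setPatterns(java.util.List<Pattern> patterns)
Parameters:
patterns - the patterns to set.
public TeachingAlgorithm getAlgorithm()
public void setAlgorithm(TeachingAlgorithm algorithm)
Parameters:
algorithm - The algorithm to set.
public void train()
train() controls the training process, which consists of the following steps:
1. Extract the training and validation sets from the provided pattern set.
2. Run the teaching algorithm until the error is small enough (the pattern sets are shuffled before each run).
3. Run retraining until the validation error is small enough (optional).
4. Prune the net (optional).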
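The steps of train() can be sketched as the control flow below. It is a minimal illustration under assumptions this page does not confirm: one cycle is one run of the teaching algorithm over the reshuffled training set, "small enough" means below maxError, and the validation phase is capped by the number of cycles the training phase actually used. TrainingLoop, its Cycle interface and the calculateError callback are hypothetical stand-ins.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.ToDoubleFunction;

/** Hypothetical sketch of the control flow of train(); not Trainer's actual implementation. */
final class TrainingLoop<P> {
    /** Hypothetical: one run of the teaching algorithm over the given patterns. */
    interface Cycle<T> {
        void run(List<T> patterns);
    }

    int cycles = 0;           // cycles used by the training phase
    int validationCycles = 0; // cycles used by the optional validation phase

    void train(List<P> training, List<P> validation, Cycle<P> cycle,
               ToDoubleFunction<List<P>> calculateError,
               double maxError, int maxCycles, Random rnd) {
        List<P> order = new ArrayList<>(training); // working copy, reshuffled before each run

        // Training phase: stop once the training error is below maxError or maxCycles is reached.
        double error = calculateError.applyAsDouble(training);
        while (error >= maxError && cycles < maxCycles) {
            Collections.shuffle(order, rnd);
            cycle.run(order);
            cycles++;
            error = calculateError.applyAsDouble(training);
        }

        // Validation phase (only with a non-empty validation set): keep training, but stop once the
        // validation error is below maxError, after at most as many cycles as the training phase used.
        if (!validation.isEmpty()) {
            double validationError = calculateError.applyAsDouble(validation);
            while (validationError >= maxError && validationCycles < cycles) {
                Collections.shuffle(order, rnd);
                cycle.run(order);
                validationCycles++;
                validationError = calculateError.applyAsDouble(validation);
            }
        }
        // Pruning (if pruneAfterTraining is set) would follow here.
    }
}
```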
private void pruneNet()
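pruneNet() is summarized as setting every connection weight lower than the pruning limit to zero. A minimal sketch over a plain weight matrix, assuming the comparison uses the weight's absolute value (a literal comparison would remove every negative weight); the Pruning class and the double[][] layout are hypothetical, because FeedForwardNet's internal representation is not documented here.

```java
/** Hypothetical pruning over a weight matrix weights[from][to]. */
final class Pruning {
    /** Sets every weight whose absolute value is below pruningLimit to 0; returns how many were pruned. */
    static int prune(double[][] weights, double pruningLimit) {
        int pruned = 0;
        for (double[] row : weights) {
            for (int j = 0; j < row.length; j++) {
                if (row[j] != 0.0 && Math.abs(row[j]) < pruningLimit) {
                    row[j] = 0.0; // connection effectively removed
                    pruned++;
                }
            }
        }
        return pruned;
    }
}
```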
public double getLearningRate()
public void setLearningRate(double learningRate)
Parameters:
learningRate - The learningRate to set.
public double getMaxError()
public void setMaxError(double maxError)
Parameters:
maxError - The maxError to set.
public int getMaxCycles()
public void setMaxCycles(int maxCycles)
Parameters:
maxCycles - The maxCycles to set.
protected double calculateError(java.util.List<Pattern> patterns)
Calculates the network's error on the provided pattern set.
Parameters:
patterns - pattern set
public int getCycle()
public double getMomentum()
public void setMomentum(double momentum)
Parameters:
momentum - The momentum to set.
public double getValidationCoeff()
public void setValidationCoeff(double validationCoeff)
Parameters:
validationCoeff - The validationCoeff to set.
public double getFlatspot()
public void setFlatspot(double flatspot)
Parameters:
flatspot - The flatspot to set.
public int getBatchsize()
public void setBatchsize(int batchsize)
Parameters:
batchsize - The batchsize to set.
public boolean getPruneAfterTraining()
public void setPruneAfterTraining(boolean pruneAfterTraining)
Parameters:
pruneAfterTraining - The pruneAfterTraining to set.
public double getPruningLimit()
public void setPruningLimit(double pruningLimit)
Parameters:
pruningLimit - The pruningLimit to set.
public double getError()
public double getValidationError()
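This page does not say which error measure calculateError(java.util.List<Pattern> patterns) uses, and the values returned by getError() and getValidationError() are presumably computed with it. A minimal sketch assuming the mean squared error over all patterns and output units; ErrorMeasure and the double[][] arrays standing in for the net's outputs and the patterns' target values are hypothetical.

```java
/** Hypothetical error measure: mean squared error over all patterns and output units. */
final class ErrorMeasure {
    /** outputs[p][k] is the net's k-th output for pattern p; targets[p][k] is the desired value. */
    static double meanSquaredError(double[][] outputs, double[][] targets) {
        if (outputs.length != targets.length) {
            throw new IllegalArgumentException("one output vector per pattern is required");
        }
        double sum = 0.0;
        int count = 0;
        for (int p = 0; p < outputs.length; p++) {
            for (int k = 0; k < outputs[p].length; k++) {
                double diff = targets[p][k] - outputs[p][k];
                sum += diff * diff;
                count++;
            }
        }
        return count == 0 ? 0.0 : sum / count; // an empty pattern set yields zero error
    }
}
```

Whatever measure Trainer actually uses, the stopping rule compares its result with maxError, so the choice of measure also fixes the scale on which maxError must be set.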