function [logZZ_est, logZZ_est_up, logZZ_est_down] = ...
RBM_AIS(vishid,hidbiases,visbiases,numruns,beta,batchdata)
% RBM_AIS  Estimate the log partition function of a binary RBM with
% Annealed Importance Sampling (AIS).
%
% Version 1.000
%
% Code provided by Ruslan Salakhutdinov
%
% Permission is granted for anyone to copy, use, modify, or distribute this
% program and accompanying programs and documents for any purpose, provided
% this copyright notice is retained and prominently displayed, along with
% a note saying that the original programs are available from our
% web page.
% The programs and documents are distributed without any warranty, express or
% implied. As the programs were written for research purposes only, they have
% not been tested to the degree that would be advisable in any important
% application. All use of these programs is entirely at the user's own risk.
%
% Inputs:
%   vishid    -- a matrix of RBM weights [numvis, numhid]
%   hidbiases -- a row vector of hidden biases [1 numhid]
%   visbiases -- a row vector of visible biases [1 numvis]
%   numruns   -- number of AIS runs (parallel annealing chains)
%   beta      -- a vector of inverse temperatures (row or column). The code
%                treats beta(1) as 0 (base-rate model) and beta(end) as 1
%                (target RBM); only beta(2:end-1) are visited explicitly.
%   batchdata -- (optional) the data divided into batches
%                (numcases numdims numbatches). It is only used to create
%                the base-rate model. If omitted, the initial (base)
%                distribution is uniform.
%
% Outputs:
%   logZZ_est      -- AIS estimate of log Z
%   logZZ_est_up   -- log of (estimate of Z + 3 standard deviations)
%   logZZ_est_down -- log of (estimate of Z - 3 standard deviations);
%                     -Inf when that quantity is not positive
%
% Side effects: closes all figures, opens figures 1 and 2, prints the
% current beta to stdout, and every 500 temperatures plots sample
% visibles (figure 1) and the variance of the log weights (figure 2).
%
% Thanks to Nicolas Le Roux for pointing out ways of making this code faster.

close all
figure('Position',[100,600,500,200]);
figure(2)
hold on
xlabel('beta','fontsize',14)
ylabel('Variance of log weights','fontsize',14)

[numdims, numhids] = size(vishid);

if nargin > 5
  %%% Initialize biases of the base rate model by ML %%%%%%%%%%%%%%%%%%%%%%%
  % NOTE(review): base_rate is an external script (not shown) that runs in
  % this workspace and is expected to define log_base_rate from batchdata.
  base_rate
  visbiases_base = log_base_rate';
else
  visbiases_base = zeros(size(visbiases));
end
numcases = numruns;

% Numerically stable helpers. They are defined after base_rate so that the
% script cannot overwrite them.
% softplus(x) = log(1+exp(x)) without overflowing to Inf for large x.
softplus = @(x) max(x,0) + log1p(exp(-abs(x)));
% Logistic sigmoid; saturates to 0/1 for large |x| instead of Inf/Inf = NaN.
sigmoid = @(x) 1./(1 + exp(-x));

%%%%%%%%%% RUN AIS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
visbias_base = repmat(visbiases_base,numcases,1); % biases of base-rate model
hidbias = repmat(hidbiases,numcases,1);
visbias = repmat(visbiases,numcases,1);

%%%% Sample from the base-rate model (beta = 0) %%%%%%%%%%%%%%%%%%%%%%%%%%
% Initialise the log importance weights with -log p*_0(v), where the
% unnormalised base distribution is exp(v*visbiases_base') * 2^numhids.
negdata = repmat(sigmoid(visbiases_base),numcases,1);
negdata = negdata > rand(numcases,numdims);
logww = -(negdata*visbiases_base' + numhids*log(2));

Wh = negdata*vishid + hidbias;
Bv_base = negdata*visbiases_base';
Bv = negdata*visbiases';

betas = beta(:)';          % iterate element-wise for row or column input
numbetas = numel(betas);
ndisp = min(10, numcases); % mnistdisp is fed at most numcases samples

tt = 1;
%%% The CORE of an AIS RUN %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
for bb = betas(2:end-1)
  fprintf(1,'beta=%g\r',bb);
  tt = tt+1;

  % Add log p*_k(v) for the current state at the new temperature bb.
  act = bb*Wh;
  logww = logww + (1-bb)*Bv_base + bb*Bv + sum(softplus(act),2);

  % One Gibbs step v -> h -> v under the intermediate distribution at bb.
  poshidstates = sigmoid(act) > rand(numcases,numhids);
  negdata = sigmoid((1-bb)*visbias_base + bb*(poshidstates*vishid' + visbias));
  negdata = negdata > rand(numcases,numdims);

  if rem(tt,500)==0
    figure(1)
    mnistdisp(negdata(1:ndisp,:)');
    figure(2)
    plot(tt/numbetas,var(logww(:)),'b*')
    hold on
    drawnow;
  end

  % Subtract log p*_k(v') for the freshly sampled state, still at bb.
  Wh = negdata*vishid + hidbias;
  Bv_base = negdata*visbiases_base';
  Bv = negdata*visbiases';
  logww = logww - ((1-bb)*Bv_base + bb*Bv + sum(softplus(bb*Wh),2));
end

% Final temperature beta = 1: add log p*(v) of the target RBM.
logww = logww + negdata*visbiases' + sum(softplus(Wh),2);

%%% Compute an estimate of logZZ_est +/- 3 standard deviations.
r_AIS = logsum(logww(:)) - log(numcases);
aa = mean(logww(:));
% Same as log(std(exp(logww(:)))/sqrt(numcases)), shifted by aa so that
% the exponentials do not overflow.
logstd_AIS = log(std(exp(logww-aa))) + aa - log(numcases)/2;

logZZ_base = sum(softplus(visbiases_base)) + numhids*log(2);
logZZ_est = r_AIS + logZZ_base;
logZZ_est_up = logsum([log(3)+logstd_AIS; r_AIS]) + logZZ_base;
logZZ_est_down = logdiff([log(3)+logstd_AIS; r_AIS]) + logZZ_base;
% If the estimate minus 3 std is negative, its log is complex. The lower
% bound on Z is then 0, i.e. -Inf in log space. This agrees with the
% exact-zero case, which already gives -Inf.
if ~isreal(logZZ_est_down)
  logZZ_est_down = -Inf;
end