function [s,l,conv]=pexpsurv_kk(t,delta,g0,r,knots,f,pw)
% PEXPSURV_KK  Kin-cohort survival estimation with piecewise exponential models.
%   [S,L,CONV] = PEXPSURV_KK(T,DELTA,G0,R,KNOTS,F,PW) estimates separate
%   piecewise exponential survival curves for mutation non-carriers and
%   carriers from relatives' data. The relatives' own carrier status is
%   treated as unobserved, and an EM algorithm weights each relative by the
%   posterior probability of each status given the proband's status.
%
% INPUT ARGUMENTS (all vectors should be column vectors)
%      t = time to event
%  delta = 0-1 indicator of failure
%     g0 = carrier status of proband (1=non-carriers, 2=carriers)
%      r = relative type (1=parents/child, 2=siblings)
%  knots = vector of knots for the piecewise constant model of hazards
%          (the first knot should not be smaller than the time of the first
%           event and the last knot should not be larger than the time of
%           the last event)
%      f = allele frequency of the mutation
%     pw = prior weights for relatives (assume 1 if no weights are needed)
%          NOTE(review): pw is accepted but not used anywhere in this
%          function -- confirm whether it should multiply the EM weights.
%
% OUTPUT ARGUMENTS
%      s = K by 2 matrix of estimated survival probabilities (K = number of knots)
%          Column 1: non-carriers; column 2: carriers.
%      l = (K+1) by 2 matrix of estimated hazards (K = number of knots)
%          Column 1: non-carriers; column 2: carriers.
%   conv = 1 if the EM algorithm converged (maximum relative change in s at
%          or below the 1e-5 tolerance before the iteration cap), else 0.

tol = 1e-005;     % convergence tolerance on the relative change in s
maxiter = 1000;   % loop runs while niter < maxiter, i.e. at most 999 EM steps

n = length(t);

% Starting values: one pooled fit (unit weights) used for both groups.
[spool,lpool] = pexp(t,delta,ones(n,1),knots);
s = [spool, spool];
l = [lpool, lpool];
s0 = s;

% Bin index of each time among the edges [0; knots; 100000]; nknots(tc) is
% then the lower edge of the subject's bin.
% NOTE(review): assumes bin2 returns 1-based bin indices -- confirm.
nknots = [0;knots;100000];
tc = bin2(t,nknots);
later = tc > 1;                  % subjects beyond the first bin
% Time elapsed within the subject's own bin. The +0.5 offset is kept as in
% the original implementation.
% NOTE(review): presumably a half-unit correction for integer-valued times -- confirm.
offset = t - nknots(tc) + 0.5;

% Transition coefficients for the relative type: slice 1 of tp is used for
% r=1 (parents/child), slice 2 for r=2 (siblings).
% NOTE(review): presumably tp(i,j,k) = P(relative status j | proband status i)
% for relative type k -- confirm against cp1exact.
tp = cp1exact(f);
p11 = (2-r).*tp(1,1,1) + (r-1).*tp(1,1,2);
p12 = (2-r).*tp(2,1,1) + (r-1).*tp(2,1,2);
p21 = (2-r).*tp(1,2,1) + (r-1).*tp(1,2,2);
p22 = (2-r).*tp(2,2,1) + (r-1).*tp(2,2,2);

surv = zeros(n,2);
haz  = zeros(n,2);
fv   = zeros(n,2);   % preallocated (previously grown inside the loop)
rerror = 1;
niter = 1;

while rerror > tol && niter < maxiter
   % E-step, part 1: likelihood contribution of each relative under each
   % status column j (hazard^delta times survival at t).
   for j = 1:2
      sstart = ones(n,1);                 % survival at the start of own bin
      sstart(later) = s(tc(later)-1, j);  % = s at the bin's lower knot
      surv(:,j) = sstart .* exp(-l(tc,j).*offset);
      haz(:,j) = l(tc,j);
      fv(:,j) = delta .* haz(:,j) .* surv(:,j) + (1 - delta) .* surv(:,j);
   end

   % E-step, part 2: normalise to posterior weights. w11/w21 are the
   % weights for status column 1/2 when the proband is a non-carrier;
   % w12/w22 are the same when the proband is a carrier.
   w11 = p11 .* fv(:,1);
   w12 = p12 .* fv(:,1);
   w21 = p21 .* fv(:,2);
   w22 = p22 .* fv(:,2);
   w11 = w11./(w11 + w21);
   w21 = 1 - w11;
   w12 = w12./(w12 + w22);
   w22 = 1 - w12;

   % M-step: weighted piecewise exponential fit per status column; the
   % proband's status g0 (1 or 2) selects which weight applies.
   [shat1,lhat1] = pexp(t, delta, w11 .* (2 - g0) + w12 .* (g0 - 1), knots);
   s(:,1) = shat1; l(:,1) = lhat1;
   [shat2,lhat2] = pexp(t, delta, w21 .* (2 - g0) + w22 .* (g0 - 1), knots);
   s(:,2) = shat2; l(:,2) = lhat2;

   % Maximum relative change in the survival estimates; the denominator is
   % floored at 0.1 so near-zero survival values do not inflate it.
   rerror = max([abs(s(:,1) - s0(:,1))./max(abs(s0(:,1)), 0.1); ...
                 abs(s(:,2) - s0(:,2))./max(abs(s0(:,2)), 0.1)]);
   s0 = s;
   niter = niter + 1;
end

% Same criterion as the loop's stopping rule, so rerror == tol also counts
% as converged (previously a strict '<' left conv = 0 in that case).
conv = double(rerror <= tol);