Modified metricclust script to use reference spike trains to classify another set of spike trains

The following script uses one set of spike trains as reference spike trains and another set as the spike trains to classify, and clusters the latter based on a distance matrix, instead of the leave-one-out method (which uses a single set of spike trains) used in metricclust.m. This may be useful, for example, if you want to test whether spike train encoding remains the same when some external conditions are changed: you could use spike trains from the 'default' condition as reference spike trains, and classify spike trains from another condition to test whether spike train clustering is maintained across conditions. The clustering method is the same as in metricclust. Use metricdist beforehand to compute distances between all pairs of spike trains from all conditions. This is somewhat suboptimal, since only a subset of those pairs is needed, namely the distances between each reference spike train and each spike train to classify. Then select the appropriate sub-matrix (reference spike trains as rows, spike trains to classify as columns) and pass it as the first argument to metricclust_ref. Anyone willing to rewrite metricdist so that it computes only the relevant distances is welcome…
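
The sub-matrix selection described above could look roughly like the sketch below. It is not part of the toolkit: the names Dall, cond, labels and cat_cls and all values are hypothetical, Dall merely stands in for the full distance matrix that metricdist would return, and category labels are assumed to be zero-based (0 to M-1), which is what the listing of metricclust_ref expects.

```matlab
% Toy stand-in for the full pairwise distance matrix computed by metricdist
% (hypothetical values: 6 spike trains, symmetric, zero diagonal).
pos  = [0 1 5 6 0.5 5.5];                 % hypothetical 1-D "positions" of the trains
Dall = abs(bsxfun(@minus, pos', pos));    % 6 x 6 distance matrix

cond   = [1 1 1 1 2 2];   % condition of each train (1 = reference, 2 = to classify)
labels = [0 0 1 1 0 1];   % zero-based category of each train (assumption)
M      = 2;               % number of categories

is_ref = (cond == 1);
is_cls = (cond == 2);

% Rows: reference spike trains; columns: spike trains to classify.
D       = Dall(is_ref, is_cls);
cat_ref = labels(is_ref);
cat_cls = labels(is_cls);   % named cat_cls to avoid shadowing MATLAB's built-in cat

% Call the function only if it is available on the path.
if exist('metricclust_ref', 'file')
    CM = metricclust_ref(D, cat_ref, cat_cls, M);
    disp(CM);
end
```

Logical masks keep the row and column order of the full matrix, so cat_ref and cat_cls stay aligned with the rows and columns of D.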

Script provided with no warranty or promises. Please email alexandre.hyafil (at) for any bug report / comment / suggestion.

     function CM = metricclust_ref(D, cat_ref, cat, M, z)
%CM = metricclust_ref(D, cat_ref, cat, M) uses a simple clustering
%   method to classify spike trains based on the distance matrix D
%   between the reference spike trains and the spike trains to classify
%   (rows: reference spike trains; columns: spike trains to classify).
%   cat_ref is a vector that gives the category indices of the
%   reference spike trains.
%   cat is a vector that gives the category indices of the
%   spike trains to classify.
%   Category indices are zero-based (0 to M-1).
%   M is the number of categories.
%   CM = metricclust_ref(D, cat_ref, cat, M, z) uses the user-defined
%   clustering exponent z (default: -2). z may also be an options
%   structure with a clustering_exponent field.
%   CM is an M-by-M confusion matrix where the columns correspond to the
%   actual classes and the rows correspond to the assigned classes.
%   For any question/comment/bug correction, please email Alexandre Hyafil:
%   alexandre.hyafil (at)
%See also metricclust from Spike Train Analysis toolkit
nref = length(cat_ref);  %number of reference spike trains
ncls = length(cat);      %number of classify spike trains
if ~isequal(size(D), [nref ncls])
    error('D should be a (number of reference spike trains) x (number of spike trains to classify) matrix');
end
if nargin < 5
    z = -2;
elseif isstruct(z)  %extract clustering exponent from opts structure
    z = z.clustering_exponent;
end
%initialize classified vectors
%cls = nan(1,ncls);
%initialize confusion matrix
CM = zeros(M,M);
%find all reference spike trains belonging to same category
grp_ref = cell(1,M);
for g=1:M
    grp_ref{g} = (cat_ref==g-1);  %logical mask; category indices are zero-based
end
ng_ref = cellfun(@nnz, grp_ref); %number of reference spike trains belonging to each category
%find zero-distance values
zeromat = (D==0);
%proportion of zero-distance values for each reference category
zeromat_grp = zeros(M,ncls);
for g=1:M
    zeromat_grp(g,:) = sum(zeromat(grp_ref{g},:), 1) / ng_ref(g);
end
%for all classify spike trains with at least one zero-distance value
zerocls = find(any(zeromat_grp>0));
maxzero = max(zeromat_grp(:,zerocls), [], 1); %maximum proportion of zero-distance for each of these spike trains
for c=1:length(zerocls)
    %which category(ies) reach the maximum value
    dec = find( zeromat_grp(:,zerocls(c)) == maxzero(c) );
    %add 1 to CM for the category reaching the maximum (or 1/n to each
    %of the n categories reaching the maximum)
    CM(dec, cat(zerocls(c))+1) = CM(dec,cat(zerocls(c))+1) + 1/length(dec);
end
%for all other classify spike trains
nozerocls = setdiff( 1:ncls, zerocls); 
%elevate distance to power z
Dz = D(:,nozerocls).^z;
%sum over each reference category, elevate to power 1/z and normalize to
%number of spike trains in each category
S = zeros(M, length(nozerocls));
for g=1:M
    S(g,:) = sum(Dz(grp_ref{g},:), 1).^(1/z) / ng_ref(g);
end
%and add to CM the category with minimal distance
mindist = min(S, [], 1);
for c=1:length(nozerocls)
    %which category(ies) reach the minimum value
    dec = find( S(:,c) == mindist(c) );
    %add 1 to CM for the category reaching the minimum (or 1/n to each
    %of the n categories reaching the minimum)
    CM(dec, cat(nozerocls(c))+1) = CM(dec,cat(nozerocls(c))+1) + 1/length(dec);
end
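
As a small worked illustration of the averaging rule applied above to spike trains with no zero distance, the sketch below computes the per-category score S for a toy distance matrix. All values are hypothetical, and the names in_g and assigned are introduced only for this example. The score follows the expression in the listing: the sum of D.^z over the reference spike trains of a category, raised to the power 1/z, then divided by the category size.

```matlab
D       = [1 4; 2 4; 4 1; 4 2];  % 4 reference trains x 2 trains to classify (toy values)
cat_ref = [0 0 1 1];             % zero-based categories of the reference trains
z       = -2;                    % clustering exponent (default of metricclust_ref)
M       = 2;                     % number of categories

S = zeros(M, size(D, 2));
for g = 1:M
    in_g   = (cat_ref == g-1);   % reference trains belonging to category g-1
    S(g,:) = sum(D(in_g,:).^z, 1).^(1/z) / nnz(in_g);
end

% Assign each train to the category with the smallest score.
[~, assigned] = min(S, [], 1);
assigned = assigned - 1;         % back to zero-based category indices
disp(assigned);                  % assigned is [0 1] for these values
```

With a negative exponent such as z = -2, the smallest distances within a category dominate its score. Note that min returns only the first minimum, whereas the listing above splits the count evenly between tied categories.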
usage/tutorials/cluster_with_reference_spike_trains.txt · Last modified: 2012/10/01 06:32 by Alexandre Hyafil