/* Copyright (c) 2024 Krypto-IT Jakub Juszczakiewicz
 * All rights reserved.
 */

// Richardson–Lucy deconvolution for 1-D signals: 16-bit PCM and 32-bit float WAV audio

// compile: gcc -O3 -fopenmp -o lucy lucy.c -lfftw3_threads -lfftw3 -lm

#include <limits.h>
#include <math.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fftw3.h>
#include <omp.h>

/* Tuning constants for the per-channel deconvolution driver in main(). */
enum {
  PADDING = 64, /* samples placed on each side of a channel before lucy() */
  ITERS = 64    /* iterations passed to lucy() for every channel */
};

/*
 * View of the start of a WAV file, overlaid directly on the bytes returned
 * by load_file(). The byte offsets noted below assume a 4-byte unsigned int,
 * a 2-byte short and no padding between fields (NOTE(review): nothing in
 * this program enforces that). The fixed fields follow the common layout in
 * which the fmt chunk is immediately followed by the data chunk; none of the
 * chunk IDs are validated by this program.
 */
struct wav_header
{
  char riff[4];               // @0  chunk ID, presumably "RIFF" (not checked)
  unsigned int fsize;         // @4  presumably the RIFF chunk size (not used)
  char wave_fmt[8];           // @8  presumably "WAVEfmt " (not checked)
  unsigned int dsize;         // @16 presumably the fmt chunk size (not used)
  unsigned short format;      // @20 main() handles 1 (16-bit PCM) and 3 (32-bit float)
  unsigned short channels;    // @22 number of interleaved channels
  unsigned int sample_rate;   // @24 (not used)
  unsigned int data_rate;     // @28 (not used)
  unsigned short sample_size; // @32 (not used)
  unsigned short sample_bits; // @34 bits per sample; 16 or 32 accepted by main()
  char data_str[4];           // @36 presumably "data"; also the base for the float-path offsets
  unsigned int section_size;  // @40 byte length of the 16-bit PCM sample data
  union {
    // @44 first samples. The arrays are placeholders only: main() indexes
    // the s16 data well past its declared length.
    short s16[16]; // 16-bit integer samples
    float f32[8]; // 32-bit float samples (not used; the float path uses its own offset)
  };
};

/*
 * Read the entire file at `path` into a newly malloc'd buffer and store it
 * in *output; the caller owns the buffer and must free() it.
 *
 * Returns the number of bytes read. Returns 0, with *output set to NULL and
 * nothing left allocated, if the file cannot be opened, sized or read in
 * full, is empty, is larger than an unsigned int can describe, or if the
 * allocation fails. No WAV validation happens here.
 */
unsigned int load_file(const char * path, struct wav_header ** output)
{
  FILE * f;
  long size;
  void * buf;
  unsigned int loaded = 0;

  *output = NULL;

  f = fopen(path, "rb");
  if (!f)
    return 0;

  if (fseek(f, 0, SEEK_END) != 0)
    goto out;
  size = ftell(f); /* -1 on error; must not be squeezed into an int */
  if (size <= 0 || (unsigned long)size > UINT_MAX)
    goto out;
  rewind(f);

  buf = malloc((size_t)size);
  if (!buf)
    goto out;
  if (fread(buf, 1, (size_t)size, f) != (size_t)size) {
    free(buf);
    goto out;
  }

  *output = buf;
  loaded = (unsigned int)size;

out:
  fclose(f);
  return loaded;
}

/*
 * Write `size` bytes from `data` to `path`, replacing any existing file.
 * Failures to open, write or close (flush) the file are reported on stderr
 * via perror(); the function never dereferences a NULL stream and does not
 * abort.
 */
void store_file(const char * path, const void * data, unsigned int size)
{
  FILE * f = fopen(path, "wb");
  if (!f) {
    perror(path);
    return;
  }

  if (fwrite(data, 1, size, f) != size)
    perror(path);
  /* fclose flushes buffered output; a failure here means data was lost. */
  if (fclose(f) != 0)
    perror(path);
}

/*
 * Pull one channel out of interleaved signed 16-bit samples and convert it
 * to doubles in [-1, 1): output[k] = input[k * channels] / 32768 for every
 * k < probes.
 */
void pcm2doubles(double * output, const short * input, size_t probes,
    size_t channels)
{
  /* 1/32768 is an exact power of two, so multiplying equals dividing. */
  const double scale = 1.0 / 32768.0;
  size_t k = 0;

  while (k < probes) {
    output[k] = input[k * channels] * scale;
    ++k;
  }
}

/*
 * Convert doubles (nominally in [-1, 1)) back to signed 16-bit samples and
 * store them into one channel of an interleaved buffer: output[k * channels]
 * for k < probes.
 *
 * Values are scaled by 32768, rounded half away from zero and saturated to
 * [-32768, 32767]. Saturation is done in floating point *before* converting
 * to an integer, because converting an out-of-range or NaN double to an
 * integer type is undefined behavior; NaN samples are written as 0.
 */
void doubles2pcm(short * output, const double * input, size_t probes,
    size_t channels)
{
  for (size_t i = 0; i < probes; i++) {
    const double scaled = round(input[i] * 32768.);
    short val;

    if (isnan(scaled))
      val = 0;
    else if (scaled > 32767.)
      val = 32767;
    else if (scaled < -32768.)
      val = -32768;
    else
      val = (short)scaled;

    output[i * channels] = val;
  }
}

/*
 * Extract one channel from interleaved 32-bit float samples
 * (input[k * channels], k < probes) into `output` as doubles, scaled so the
 * sample with the largest magnitude lands at +/-0.75.
 *
 * Scaling uses the peak magnitude rather than the signed maximum. Dividing
 * by the signed maximum has three problems:
 *   - a silent channel divides by zero (NaN output);
 *   - an all-negative channel has its polarity flipped;
 *   - negative peaks can exceed -0.75.
 *
 * Returns the largest *signed* sample value of the channel (unchanged
 * contract), which the caller passes to doubles2float() to restore the
 * level. Returns 0 and writes nothing when probes is 0. A channel whose
 * samples are all zero yields all-zero output.
 */
float float2doubles(double * output, const float * input, size_t probes,
    size_t channels)
{
  if (probes == 0)
    return 0.0f;

  float max = input[0];
  double peak = 0.0;
  for (size_t i = 0; i < probes; i++) {
    const float v = input[i * channels];
    const double mag = v < 0 ? -(double)v : (double)v;

    if (max < v)
      max = v;
    if (mag > peak)
      peak = mag;
  }

  for (size_t i = 0; i < probes; i++) {
    output[i] = peak > 0.0
        ? 0.75 * ((double)input[i * channels] / peak)
        : 0.0;
  }

  return max;
}

/*
 * Write one channel of processed doubles back into interleaved float storage
 * (output[k * channels], k < probes), rescaled so that the largest value in
 * `input` becomes `max` (normally the value float2doubles() returned).
 *
 * Does nothing when probes is 0 (the original read input[0] regardless).
 * If the largest input value is 0 or NaN there is no usable rescale factor,
 * so the values are stored unscaled instead of becoming inf/NaN.
 */
void doubles2float(float * output, const double * input, size_t probes,
    size_t channels, float max)
{
  if (probes == 0)
    return;

  double dmax = input[0];
  for (size_t i = 1; i < probes; i++) {
    if (dmax < input[i])
      dmax = input[i];
  }

  const double scale = (dmax != 0.0 && !isnan(dmax)) ? max / dmax : 1.0;

  for (size_t i = 0; i < probes; i++)
    output[i * channels] = (float)(input[i] * scale);
}

/*
 * In-place complex product *a <- (*a) * b, where element [0] of each
 * fftw_complex is the real part and element [1] the imaginary part.
 */
static void complex_mul(fftw_complex * a, fftw_complex b)
{
  double * const z = *a;
  const double re = z[0];
  const double im = z[1];

  z[0] = re * b[0] - im * b[1];
  z[1] = re * b[1] + im * b[0];
}

/*
 * Map buf[0..n) linearly onto [-1, 1]: the minimum goes to -1 and the
 * maximum to +1. The map is invariant to any positive scale factor, so it
 * also cancels the unnormalized gain of an FFTW r2c/c2r round trip.
 * A buffer with no usable range is set to 0 instead of becoming 0/0 = NaN
 * everywhere. That covers all values equal, and a NaN first element.
 * Requires n >= 1.
 */
static void lucy_rescale_unit(double * buf, size_t n)
{
  double min = buf[0];
  double max = buf[0];

  for (size_t j = 1; j < n; j++) {
    if (buf[j] < min)
      min = buf[j];
    else if (buf[j] > max)
      max = buf[j];
  }

  const double range = max - min;
  if (!(range > 0)) {
#pragma omp parallel for
    for (size_t j = 0; j < n; j++)
      buf[j] = 0.0;
    return;
  }

#pragma omp parallel for
  for (size_t j = 0; j < n; j++)
    buf[j] = (buf[j] - min) * 2 / range - 1;
}

/* Multiply each of the `bins` spectrum entries by the matching kernel bin
 * (i.e. circular convolution with the kernel once transformed back). */
static void lucy_apply_kernel(fftw_complex * spectrum, fftw_complex * kernel,
    size_t bins)
{
#pragma omp parallel for
  for (size_t j = 0; j < bins; j++)
    complex_mul(&spectrum[j], kernel[j]);
}

/*
 * Sharpen `data` (length `probes`) in place with `iters` iterations of a
 * Richardson–Lucy-style update. The estimate starts as a copy of `data` and
 * the final estimate is written back to `data`.
 *
 * Each iteration computes, with (*) a circular convolution done through
 * FFTW r2c/c2r transforms and rescale() mapping a buffer onto [-1, 1]:
 *   ratio    = rescale(data / (estimate (*) K))
 *   estimate = rescale(estimate * (ratio (*) K))
 *
 * K is the transform of sin(pi * i / (32 * probes)) for i in [0, probes).
 * NOTE(review): over that range this is a small monotonic ramp
 * (0 .. ~0.098), not a localized PSF. Textbook Richardson–Lucy also
 * correlates the ratio with the mirrored PSF, i.e. multiplies by conj(K),
 * instead of by K again. Both choices are kept as-is; confirm intent.
 * A zero in (estimate (*) K) makes the ratio inf/NaN.
 *
 * If probes is 0 or exceeds FFTW's int length, or any allocation or plan
 * creation fails, `data` is left unchanged.
 */
void lucy(double * data, size_t probes, size_t iters)
{
  if (probes == 0 || probes > (size_t)INT_MAX)
    return;

  const size_t bins = probes / 2 + 1; /* r2c output length */
  fftw_complex * spectrum = fftw_malloc(sizeof(fftw_complex) * bins);
  fftw_complex * kernel = fftw_malloc(sizeof(fftw_complex) * bins);
  double * estimate = fftw_malloc(sizeof(double) * probes);
  double * ratio = fftw_malloc(sizeof(double) * probes);
  fftw_plan fwd_estimate = NULL, inverse = NULL, fwd_ratio = NULL;

  if (!spectrum || !kernel || !estimate || !ratio)
    goto cleanup;

  /* FFTW_ESTIMATE planning does not touch the arrays, so they may be filled
   * after the plans are created. */
  fwd_estimate = fftw_plan_dft_r2c_1d((int)probes, estimate, spectrum,
      FFTW_ESTIMATE);
  inverse = fftw_plan_dft_c2r_1d((int)probes, spectrum, ratio, FFTW_ESTIMATE);
  fwd_ratio = fftw_plan_dft_r2c_1d((int)probes, ratio, spectrum,
      FFTW_ESTIMATE);
  if (!fwd_estimate || !inverse || !fwd_ratio)
    goto cleanup;

  /* Kernel spectrum: transform the ramp through the estimate buffer, which
   * only receives `data` afterwards. */
  for (size_t i = 0; i < probes; i++)
    estimate[i] = sin((M_PI * i) / (probes * 32));
  fftw_execute(fwd_estimate);
  memcpy(kernel, spectrum, sizeof(fftw_complex) * bins);
  memcpy(estimate, data, sizeof(double) * probes);

  for (size_t it = 0; it < iters; it++) {
    /* ratio = rescale(data / (estimate (*) K)) */
    fftw_execute(fwd_estimate);
    lucy_apply_kernel(spectrum, kernel, bins);
    fftw_execute(inverse); /* c2r also clobbers `spectrum`; it is recomputed */
#pragma omp parallel for
    for (size_t j = 0; j < probes; j++)
      ratio[j] = data[j] / ratio[j];
    lucy_rescale_unit(ratio, probes);

    /* estimate = rescale(estimate * (ratio (*) K)); `ratio` doubles as the
     * destination of the inverse transform. */
    fftw_execute(fwd_ratio);
    lucy_apply_kernel(spectrum, kernel, bins);
    fftw_execute(inverse);
#pragma omp parallel for
    for (size_t j = 0; j < probes; j++)
      estimate[j] *= ratio[j];
    lucy_rescale_unit(estimate, probes);
  }

  memcpy(data, estimate, sizeof(double) * probes);

cleanup:
  if (fwd_estimate)
    fftw_destroy_plan(fwd_estimate);
  if (inverse)
    fftw_destroy_plan(inverse);
  if (fwd_ratio)
    fftw_destroy_plan(fwd_ratio);
  if (spectrum)
    fftw_free(spectrum);
  if (kernel)
    fftw_free(kernel);
  if (estimate)
    fftw_free(estimate);
  if (ratio)
    fftw_free(ratio);
}

/*
 * Deconvolve every channel of a 16-bit PCM file in place.
 *
 * Samples start at the byte offset of `s16` (44 with the usual ABI sizes),
 * and wave->section_size gives their byte length. That length is clamped to
 * what the file actually holds, so a corrupt header cannot cause reads or
 * writes past the buffer. The caller guarantees that fsize covers the fixed
 * header and that wave->channels != 0.
 * Returns 0 on success, -1 on allocation failure.
 */
static int process_pcm16(struct wav_header * wave, size_t fsize)
{
  const size_t data_off = offsetof(struct wav_header, s16);
  const size_t channels = wave->channels;
  size_t bytes = wave->section_size;

  if (bytes > fsize - data_off)
    bytes = fsize - data_off;
  const size_t probes = bytes / (channels * 2);
  if (probes == 0)
    return 0; /* nothing to deconvolve; the file is written back unchanged */

  /* Address the samples from the start of the allocation rather than
   * through the 16-element s16 placeholder array. */
  short * in = (short *)((unsigned char *)wave + data_off);
  double * input = fftw_malloc(sizeof(double) * (probes + 2 * PADDING));
  if (!input)
    return -1;

  for (size_t ch = 0; ch < channels; ch++) {
    /* lucy() writes its result over the whole buffer, padding included, so
     * the padding must be re-zeroed for every channel. */
    memset(input, 0, sizeof(double) * PADDING);
    memset(input + PADDING + probes, 0, sizeof(double) * PADDING);
    pcm2doubles(&input[PADDING], &in[ch], probes, channels);
    lucy(input, probes + 2 * PADDING, ITERS);
    doubles2pcm(&in[ch], &input[PADDING], probes, channels);
  }

  fftw_free(input);
  return 0;
}

/*
 * Deconvolve every channel of a 32-bit float file in place.
 *
 * Uses the hard-coded layout of the original code:
 *   - the data byte length is an int at data_str + 18 (byte 54 with the
 *     usual ABI sizes);
 *   - samples start at data_str + 22 (byte 58).
 * NOTE(review): presumably this is a float WAV with an 18-byte fmt chunk
 * followed by a fact chunk; other layouts are misread. Confirm against the
 * expected inputs.
 *
 * Byte 58 is not 4-byte aligned, so the size is read with memcpy and the
 * samples are processed in an aligned copy. The declared length is clamped
 * to the file size. The caller guarantees wave->channels != 0.
 * Returns 0 on success, -1 if the file is too short, the length is
 * negative, or an allocation fails.
 */
static int process_float32(struct wav_header * wave, size_t fsize)
{
  _Static_assert(sizeof(float) == 4, "32-bit float samples expected");
  const size_t size_off = offsetof(struct wav_header, data_str) + 18;
  const size_t data_off = offsetof(struct wav_header, data_str) + 22;
  unsigned char * raw = (unsigned char *)wave;
  const size_t channels = wave->channels;
  int declared;

  if (fsize < data_off)
    return -1;
  memcpy(&declared, raw + size_off, sizeof declared);
  if (declared < 0)
    return -1;

  size_t bytes = (size_t)declared;
  if (bytes > fsize - data_off)
    bytes = fsize - data_off;
  const size_t probes = bytes / (channels * 4);
  if (probes == 0)
    return 0; /* nothing to deconvolve; the file is written back unchanged */
  const size_t nsamples = probes * channels;

  float * samples = malloc(sizeof(float) * nsamples);
  double * input = fftw_malloc(sizeof(double) * (probes + 2 * PADDING));
  if (!samples || !input) {
    free(samples);
    if (input)
      fftw_free(input);
    return -1;
  }
  memcpy(samples, raw + data_off, sizeof(float) * nsamples);

  for (size_t ch = 0; ch < channels; ch++) {
    /* lucy() overwrites the padding too; re-zero it for every channel. */
    memset(input, 0, sizeof(double) * PADDING);
    memset(input + PADDING + probes, 0, sizeof(double) * PADDING);
    float max = float2doubles(&input[PADDING], &samples[ch], probes,
        channels);
    lucy(input, probes + 2 * PADDING, ITERS);
    doubles2float(&samples[ch], &input[PADDING], probes, channels, max);
  }

  memcpy(raw + data_off, samples, sizeof(float) * nsamples);
  free(samples);
  fftw_free(input);
  return 0;
}

/*
 * Usage: lucy <input.wav> <output.wav>
 *
 * Loads the input file, deconvolves each channel with lucy() and writes the
 * modified bytes to the output path. Two sample formats are handled:
 * format 1 with 16-bit samples and format 3 with 32-bit samples. Any other
 * format/bit-depth pair is printed to stdout and the program exits with
 * status 1. Load, header-validation and processing failures also exit
 * with 1.
 */
int main(int argc, char * argv[])
{
  if (argc < 3) {
    fputs("usage: lucy <input.wav> <output.wav>\n", stderr);
    return 1;
  }

  struct wav_header * wave = NULL;
  unsigned int fsize = load_file(argv[1], &wave);
  if (!fsize)
    return 1;

  /* Every field read below lives in the fixed header. Reject files shorter
   * than the header, and a zero channel count, which is used as a divisor. */
  if (fsize < offsetof(struct wav_header, s16) || wave->channels == 0) {
    fprintf(stderr, "%s: truncated or malformed WAV header\n", argv[1]);
    free(wave);
    return 1;
  }

  const int threads = fftw_init_threads(); /* 0 on failure: stay single-threaded */
  if (threads)
    fftw_plan_with_nthreads(omp_get_max_threads());

  int rc;
  if ((wave->format == 1) && (wave->sample_bits == 16)) {
    rc = process_pcm16(wave, fsize);
  } else if ((wave->format == 3) && (wave->sample_bits == 32)) {
    rc = process_float32(wave, fsize);
  } else {
    printf("%d %d\n", wave->format, wave->sample_bits);
    free(wave);
    if (threads)
      fftw_cleanup_threads();
    return 1;
  }

  if (rc != 0)
    fprintf(stderr, "%s: processing failed\n", argv[1]);
  else
    store_file(argv[2], wave, fsize);

  free(wave);
  if (threads)
    fftw_cleanup_threads();

  return rc == 0 ? 0 : 1;
}
