# YUV 原画から JM を使ってエンコードを行い、RD と実行時間を確認するスクリプト

src = [
'susie',
'football',
'mobile'
]

def get_snr(filename)
    snr = 0.0
    open(filename) do |fd|
        while (line = fd.gets()) != nil do
            if line['avg'] then
                snr = line.split("\t").at(1).to_f
            end
        end
    end
    return snr
end

def get_bps(filename, count, fps)
    return FileTest.size(filename) * 8 * fps / count
end

def test(tag, width, height, fps, count, qp, gop, bfrm, cabac, cost_function, pred, t8x8, matrix, mbaff, rdo)

    # コマンドラインをでっち上げる
    # 基本パラメータ
    opt  = sprintf(" -p InputFile=yuv\\%s.420 -p SourceWidth=%d -p SourceHeight=%d -p FrameRate=%.4f", tag, width, height, fps)
    
    # QP 関連
    opt <<  sprintf(" -p QPISlice=%d -p QPPSlice=%d -p QPBSlice=%d", qp, qp, qp)

    # GOP 構造関連
    if (bfrm > 0) then
        opt << sprintf(" -p FramesToBeEncoded=%d -p IntraPeriod=%d -p FrameSkip=%d -p NumberBFrames=%d -p NumberReferenceFrames=2 -p LastFrameNumber=%d", count/(bfrm+1), gop/(bfrm+1), bfrm, bfrm, count-1)
    else
        opt << sprintf(" -p FramesToBeEncoded=%d -p IntraPeriod=%d -p FrameSkip=0 -p NumberBFrames=0 -p NumberReferenceFrames=1", count, gop)
    end
    
    # エントロピー符号化
    if (cabac > 0) then
        opt << sprintf(" -p SymbolMode=%d", 1)
        opt << sprintf(" -p ContextInitMethod=%d -p FixedModelNumber=%d", 0, 0)
    else
        opt << sprintf(" -p SymbolMode=%d", 0)
    end
    
    # コスト評価関数
    if (cost_function > 0) then
        opt << sprintf(" -p MEDistortionFPel=%d -p MEDistortionHPel=%d -p MEDistortionQPel=%d -p MDDistortion=%d", 0, 2, 2, 2)
    else
        opt << sprintf(" -p MEDistortionFPel=%d -p MEDistortionHPel=%d -p MEDistortionQPel=%d -p MDDistortion=%d", 0, 0, 0, 0)
    end
    
    # インター予測ブロックサイズ
    case pred
    when 0 # 16x16 だけ
        opt << sprintf(" -p InterSearch16x16=%d -p InterSearch16x8=%d  -p InterSearch8x16=%d  -p InterSearch8x8=%d", 1, 0, 0, 0)
    when 1 # 16x16 と 8x8 だけ
        opt << sprintf(" -p InterSearch16x16=%d -p InterSearch16x8=%d  -p InterSearch8x16=%d  -p InterSearch8x8=%d", 1, 0, 0, 1)
    when 2 # 16x16 と 16x8/8x16 だけ
        opt << sprintf(" -p InterSearch16x16=%d -p InterSearch16x8=%d  -p InterSearch8x16=%d  -p InterSearch8x8=%d", 1, 1, 1, 0)
    when 3 # 16x16 〜 8x8 まで
        opt << sprintf(" -p InterSearch16x16=%d -p InterSearch16x8=%d  -p InterSearch8x16=%d  -p InterSearch8x8=%d", 1, 1, 1, 1)
    end
    # 8x4/4x8 と 4x4 は使わない
    opt << sprintf(" -p InterSearch8x4=%d -p InterSearch4x8=%d -p InterSearch4x4=%d", 0, 0, 0)
    
    # 8x8 変換
    if (t8x8 > 0) then
        opt << sprintf(" -p Transform8x8Mode=%d", 1)
    else
        opt << sprintf(" -p Transform8x8Mode=%d", 0)
    end
    
    # 量子化マトリックス
    if (matrix > 0) then
        # JVT default
        opt << sprintf(" -p ScalingMatrixPresentFlag=%d -p ScalingListPresentFlag0=%d -p ScalingListPresentFlag1=%d -p ScalingListPresentFlag2=%d -p ScalingListPresentFlag3=%d -p ScalingListPresentFlag4=%d -p ScalingListPresentFlag5=%d -p ScalingListPresentFlag6=%d -p ScalingListPresentFlag7=%d", 1, 0, 0, 0, 0, 0, 0, 0, 0)
    else
        # flat 16
        opt << sprintf(" -p ScalingMatrixPresentFlag=%d -p ScalingListPresentFlag0=%d -p ScalingListPresentFlag1=%d -p ScalingListPresentFlag2=%d -p ScalingListPresentFlag3=%d -p ScalingListPresentFlag4=%d -p ScalingListPresentFlag5=%d -p ScalingListPresentFlag6=%d -p ScalingListPresentFlag7=%d", 0, 0, 0, 0, 0, 0, 0, 0, 0)
    end
    
    # MBAFF 関連
    if (mbaff > 0) then
        # MBAFF 有効 (適応判定)
        opt << sprintf(" -p PicInterlace=%d -p MbInterlace=%d", 0, 2)
        rdo = 1 # この場合 rdo=1 でなければエンコードできない (JM の制限)
    else
        # MBAFF 無効
        opt << sprintf(" -p PicInterlace=%d -p MbInterlace=%d", 0, 0)
    end
    
    # rdo 関連
    if (rdo > 0) then
        # RDO 有効 (slow)
        opt << sprintf(" -p RDOptimization=%d", 1)
    else
        # RDO 無効
        opt << sprintf(" -p RDOptimization=%d", 0)
    end

    sec = Time.now()

    cmd = sprintf("bin\\lencod %s > dummy.log", opt)
    $stderr.printf("%s\n", cmd)
    system(cmd)

    sec = Time.now() - sec
        
    cmd = "bin\\ldecod decoder.cfg > dummy.log"
    system(cmd)
    cmd = sprintf("bin\\yuvsnr -w %d -h %d yuv\\%s.420 test_dec.yuv > snr.log", width, height, tag)
    system(cmd)

    snr = get_snr("snr.log")
    bps = get_bps("test.264", count, fps)

    printf("%d\t%f\t%f\t%f\n", qp, bps, snr, sec.to_f)

    File.unlink("test.264")
    File.unlink("test_dec.yuv")
    File.unlink("dataDec.txt")
    File.unlink("log.dec")
    File.unlink("snr.log")
    File.unlink("dummy.log")
    File.unlink("stats.dat")
    File.unlink("log.dat")
    File.unlink("test_rec.yuv")
    File.unlink("data.txt")

end

src.each do |tag|
    width = 720
    height = 480
    fps = 29.97
    count = 260
    gop = 15
    bfrm = 2
    cabac = 1
    pred = 3
    t8x8 = 1
    matrix = 0
    mbaff = 0
    [0, 1].each do |cost_function|
        [0, 1].each do |rdo|
            printf("tag=%s, gop%d, bfrm=%d, cabac=%d, cost_function=%d, pred=%d, t8x8=%d, matrix=%d, mbaff=%d, rdo=%d\n", tag, gop, bfrm, cabac, cost_function, pred, t8x8, matrix, mbaff, rdo)
            printf("qp\tbps\tsnr\tsec\n")
            18.step(32, 2) do |qp|
                test(tag, width, height, fps, count, qp, gop, bfrm, cabac, cost_function, pred, t8x8, matrix, mbaff, rdo)
            end
            printf("\n")
        end
    end
end
