/*******************************************************************
                 Program Stream handling module
 *******************************************************************/

#include <stdlib.h>
#include <string.h>
#include "pes.h"

#define PROGRAM_STREAM_C
#include "program_stream.h"

/*
 * Per-handle state for pulling one elementary stream out of a program
 * stream. Callers receive the address of this structure cast to int.
 */
typedef struct {

	/* underlying bit-stream reader, opened with bs_open() */
	BITSTREAM      *bs;
	/* stream type requested by the caller; ps_open() fills in .id */
	PES_STREAM_TYPE type;

	/* payload of the most recently extracted PES packet */
	unsigned char   buffer[MAX_PACKET_DATA_LENGTH];
	/* next byte of buffer not yet handed to the caller */
	unsigned char  *current;

	/* number of payload bytes in buffer not yet returned by ps_read() */
	unsigned int    data_rest;
	/* 0x100 + stream_id of the selected stream; matched by ps_seek() */
	int             packet_start_code;

} PROGRAM_STREAM;

/*
 * Module entry points. ps_open() returns a handle (or -1 on failure);
 * every other function takes that handle as its first argument.
 */
int ps_open(const char *filename, int stream_type);
int __cdecl ps_close(int in);
int __cdecl ps_read(int in, void *data, unsigned int count);
__int64 __cdecl ps_tell(int in);
__int64 __cdecl ps_seek(int in, __int64 offset, int origin);

int ps_open(const char *filename, int stream_type)
{
	PROGRAM_STREAM *ps;
	PES_PACKET packet;
	PES_STREAM_TYPE type;

	ps = (PROGRAM_STREAM *)calloc(1, sizeof(PROGRAM_STREAM));
	if(ps == NULL){
		return -1;
	}
	
	ps->bs = bs_open(filename);
	if(ps->bs == NULL){
		return -1;
	}

	ps->type.type = stream_type;
	
	while(read_pes_packet(ps->bs, &packet)){
		if(extract_pes_stream_type(&packet, &type)){
			if(type.type == ps->type.type){
				ps->type.id = type.id;
				bs_seek(ps->bs, 0, SEEK_SET);
				ps->packet_start_code = 0x100 + packet.stream_id;
				return (int)ps;
			}
		}
	}

	bs_close(ps->bs);

	return -1;
}

/*
 * Releases a handle obtained from ps_open(): closes its bit-stream (if any)
 * and frees the state. A zero handle is accepted and ignored.
 * Always returns 1.
 */
int __cdecl ps_close(int in)
{
	PROGRAM_STREAM *stream = (PROGRAM_STREAM *)in;

	if(stream == NULL){
		return 1;
	}

	if(stream->bs != NULL){
		bs_close(stream->bs);
	}
	free(stream);

	return 1;
}

int __cdecl ps_read(int in, void *data, unsigned int count)
{
	int r;
	
	PROGRAM_STREAM *ps;
	PES_PACKET packet;
	PES_STREAM_TYPE type;

	ps = (PROGRAM_STREAM *)in;

	if(ps->data_rest){
		memcpy(data, ps->current, ps->data_rest);
		r = ps->data_rest;
		ps->data_rest = 0;
		return r;
	}

	while(read_pes_packet(ps->bs, &packet) ){
		if( extract_pes_stream_type(&packet, &type) ){
			if( (type.type == ps->type.type) && (type.id == ps->type.id) ){
				extract_pes_packet_data(&packet, ps->buffer, &(ps->data_rest) );
				if(ps->data_rest <= count){
					memcpy(data, ps->buffer, ps->data_rest);
					r = ps->data_rest;
					ps->data_rest = 0;
					return r;
				}else{
					memcpy(data, ps->buffer, count);
					ps->data_rest -= count;
					ps->current = ps->buffer + count;
					return count;
				}
			}
		}
	}
	
	return 0;
}

/*
 * Repositions the reader near byte `offset` of the underlying file.
 * `offset` and `origin` are passed straight to bs_seek().
 *
 * It first scans backwards from the seek point for a start code of the
 * selected stream. A packet is accepted only if it ends at or after the
 * seek target. Otherwise it scans forwards for the next packet of the
 * stream. The accepted packet's payload is loaded into ps->buffer, so the
 * next ps_read() continues from there.
 *
 * Returns the resulting logical position as computed by ps_tell().
 */
__int64 __cdecl ps_seek(int in, __int64 offset, int origin)
{
	__int64 n;
	PROGRAM_STREAM *ps;
	PES_PACKET      packet;
	PES_STREAM_TYPE type;

	ps = (PROGRAM_STREAM *)in;

	n = bs_seek(ps->bs, offset, origin);

	/* Drop payload buffered before the seek: if no packet is found below,
	   stale data_rest would otherwise skew ps_tell() and the next ps_read(). */
	ps->data_rest = 0;
	ps->current = ps->buffer;

	while(bs_previous_packet_prefix(ps->bs)){
		if( bs_read_bits(ps->bs, 32) == ps->packet_start_code ){
			/* NOTE(review): the 4 start-code bytes have just been consumed; confirm
			   read_pes_packet() parses this packet rather than scanning ahead to the next. */
			if( !read_pes_packet(ps->bs, &packet) ||
			    !extract_pes_stream_type(&packet, &type) ){
				/* unreadable packet: `type` would be uninitialized, fall back to forward scan */
				break;
			}
			if(! ( type.type == ps->type.type && type.id == ps->type.id) ){
				break;
			}
			if( bs_tell(ps->bs) < n ){
				/* packet ends before the seek target; continue with the forward scan */
				break;
			}
			extract_pes_packet_data(&packet, ps->buffer, &(ps->data_rest));
			ps->current = ps->buffer;
			return ps_tell(in);
		}else{
			/* presumably steps back over the 32 bits just read, so the backward
			   search resumes from this prefix */
			ps->bs->current -= 4;
		}
	}

	while( read_pes_packet(ps->bs, &packet) ){
		if( extract_pes_stream_type(&packet, &type) ){
			if( (type.type == ps->type.type) && (type.id == ps->type.id) ){
				extract_pes_packet_data(&packet, ps->buffer, &(ps->data_rest));
				ps->current = ps->buffer;
				break;
			}
		}
	}

	return ps_tell(in);
}

/*
 * Reports the logical read position of the elementary stream. This is the
 * bit-stream position minus the payload bytes still buffered (extracted
 * but not yet delivered by ps_read).
 */
__int64 __cdecl ps_tell(int in)
{
	PROGRAM_STREAM *stream = (PROGRAM_STREAM *)in;
	__int64 position = bs_tell(stream->bs);

	return position - stream->data_rest;
}

