ML.NET 2- 预测出租车价格

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/csharp25/article/details/84680443

1. 预备测试数据
2. 加载模型
3. 训练
4. 预测

实现:

TaxiFarePrediction.cs:
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;

namespace _01_TaxiFare
{
    public class TaxiFarePrediction
    {
        static readonly string _datapath = Path.Combine(Environment.CurrentDirectory,  "taxi-fare-train.csv");
        static readonly string _testdatapath = Path.Combine(Environment.CurrentDirectory,  "taxi-fare-test.csv");
        static readonly string _modelpath = Path.Combine(Environment.CurrentDirectory, "Model.zip");

        public static async Task<TaxiTripFarePrediction> Predict(TaxiTrip tt)
        {
            var model = await Train();
            Evaluate(model);

            return model.Predict(tt);
        }

        private static async Task<PredictionModel<TaxiTrip, TaxiTripFarePrediction>> Train()
        {
            
            var pipeline = new LearningPipeline
            {
                new TextLoader(_datapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ','),
                new ColumnCopier(("FareAmount", "Label")),
                new CategoricalOneHotVectorizer(
                    "VendorId",
                    "RateCode",
                    "PaymentType"),
                new ColumnConcatenator(
                    "Features",
                    "VendorId",
                    "RateCode",
                    "PassengerCount",
                    "TripDistance",
                    "PaymentType"),
                new FastTreeRegressor()
            };
            PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = pipeline.Train<TaxiTrip, TaxiTripFarePrediction>();
            await model.WriteAsync(_modelpath);
            return model;
        }
        private static void Evaluate(PredictionModel<TaxiTrip, TaxiTripFarePrediction> model)
        {
            var testData = new TextLoader(_testdatapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ',');
            var evaluator = new RegressionEvaluator();
            RegressionMetrics metrics = evaluator.Evaluate(model, testData);
            Console.WriteLine($"Rms = {metrics.Rms}");
            Console.WriteLine($"RSquared = {metrics.RSquared}");
        }
    }
}

TaxiTrip.cs:

using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Runtime.Api;

namespace _01_TaxiFare
{
    public class TaxiTrip
    {
        [Column("0")]
        public string VendorId;

        [Column("1")]
        public string RateCode;

        [Column("2")]
        public float PassengerCount;

        [Column("3")]
        public float TripTime;

        [Column("4")]
        public float TripDistance;

        [Column("5")]
        public string PaymentType;

        [Column("6")]
        public float FareAmount;
    }

    public class TaxiTripFarePrediction
    {
        [ColumnName("Score")]
        public float FareAmount;
    }
}

调用:

using System;

namespace _01_TaxiFare
{
    class Program
    {
        static void Main(string[] args)
        {
            var prediction = TaxiFarePrediction.Predict(new TaxiTrip
            {
                VendorId = "VTS",
                RateCode = "1",
                PassengerCount = 1,
                TripDistance = 10.33f,
                PaymentType = "CSH",
                FareAmount = 0 // predict it. actual = 29.5
            }).Result;

            Console.WriteLine("Predicted fare: {0}, actual fare: 29.5", prediction.FareAmount);

            Console.ReadLine();
        }
    }
}


 

猜你喜欢

转载自blog.csdn.net/csharp25/article/details/84680443